PyTorch同時讀取兩個數據集實現半監督學習

PyTorch同時讀取兩個數據集實現半監督學習

寫在開頭

本文中的實驗是在Kaggle上完成的,內容直接從Kaggle Notebook導出。由於還沒寫完,後期應該還會更新……

https://www.kaggle.com/lartpang/segmentationdataloader

# Each entry is [image_dir, mask_dir]; Pascal-S serves as the labeled set and ECSSD as the unlabeled set.
LABELED_PATH = ["/kaggle/input/pascal-s/Pascal-S/Image", "/kaggle/input/pascal-s/Pascal-S/Mask"]
UNLABELED_PATH = ["/kaggle/input/ecssd/ECSSD/Image", "/kaggle/input/ecssd/ECSSD/Mask"]

TODO

  • 讀取DUTS-TR和MixFlickrDUTS用於訓練
  • 每個batch都要保證包含1/4的DUTS-TR的數據集和3/4的MixFlickrDUTS
  • 針對訓練集使用不同的增強方式
  • 嘗試更多的方法

不考慮測試集,因爲測試集完全可以使用一個獨立的ImageFolder類構造。

方法一:通過對__getitem__的索引進行計算,按照比例關係選擇對應數據集的數據

# Indices form groups of (r_l_rate + 1): the first index of a group is a labeled sample,
# the other r_l_rate indices are consecutive, non-overlapping unlabeled samples.
if index % (self.r_l_rate + 1) == 0:
    label_index = index // (self.r_l_rate + 1)
    img_path, gt_path = self.imgs_label[label_index]  # group g -> labeled sample g
else:
    # Offset inside the group is 1..r_l_rate; group g owns unlabeled samples
    # g * r_l_rate ... g * r_l_rate + r_l_rate - 1, so groups never share samples.
    unlabel_index = (index // (self.r_l_rate + 1)) * self.r_l_rate + index % (self.r_l_rate + 1) - 1
    img_path, gt_path = self.imgs_unlabel[unlabel_index]

主體代碼:

import os

import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from torch.utils.data import DataLoader
import math


class JointResize(object):
    """Resize an image and its mask to one shared size.

    ``size`` is either an int (square output) or a tuple that is handed to each
    object's ``resize`` unchanged (PIL reads it as ``(width, height)``).
    """

    def __init__(self, size):
        if isinstance(size, int):
            size = (size, size)
        elif not isinstance(size, tuple):
            raise RuntimeError("size參數請設置爲int或者tuple")
        self.size = size

    def __call__(self, img, mask):
        target = self.size
        return img.resize(target), mask.resize(target)

def make_dataset(root, prefix=('jpg', 'png')):
    """List ``(image_path, mask_path)`` pairs for every image in ``root[0]``.

    Args:
        root: ``(image_dir, mask_dir)``.
        prefix: ``(image_suffix, mask_suffix)``. Despite the name these are file
            suffixes; a missing leading dot is added, so ``('jpg', 'png')`` and
            ``('.jpg', '.png')`` are equivalent.

    Returns:
        A list of ``(image_path, mask_path)`` tuples, sorted by image name so the
        order does not depend on the filesystem. Mask files are not checked for
        existence.
    """
    img_dir, gt_dir = root[0], root[1]
    # Without the dot, `name + 'jpg'` would build paths such as '0001jpg'.
    img_suffix = prefix[0] if prefix[0].startswith('.') else '.' + prefix[0]
    gt_suffix = prefix[1] if prefix[1].startswith('.') else '.' + prefix[1]
    img_names = sorted(
        os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith(img_suffix)
    )
    return [
        (os.path.join(img_dir, name + img_suffix), os.path.join(gt_dir, name + gt_suffix))
        for name in img_names
    ]


# Training set only
class ImageFolder(data.Dataset):
    """Training dataset that interleaves labeled and unlabeled samples at a fixed ratio.

    With ``r_l_rate = split_rate[1] // split_rate[0]``, indices form consecutive groups
    of ``r_l_rate + 1``. The first index of a group is a labeled sample and the other
    ``r_l_rate`` indices are distinct unlabeled samples. Reading the dataset in order
    (``shuffle=False``) with a batch size that is a multiple of ``r_l_rate + 1`` gives
    every batch the same labeled:unlabeled ratio. The labeled samples sit at positions
    ``0, r_l_rate + 1, 2 * (r_l_rate + 1), ...`` of the batch.

    Args:
        root: ``(labeled_root, unlabeled_root)``, each an ``(image_dir, mask_dir)``
            pair passed to ``make_dataset``.
        mode: mode name; only type-checked and stored.
        in_size: output size for ``JointResize`` (int or tuple).
        prefix: ``(image_suffix, mask_suffix)`` passed to ``make_dataset``.
        use_bigt: if True, masks are binarized with a 0.5 threshold.
        split_rate: ``(labeled, unlabeled)`` ratio; only the floored quotient is used.

    Items are ``(img, gt, img_name)``. For unlabeled samples ``gt`` is read from the
    unlabeled set's mask directory as well.
    """

    def __init__(self, root, mode, in_size, prefix, use_bigt=False, split_rate=(1, 3)):
        assert isinstance(mode, str), 'mode參數錯誤,應該爲str類型'
        self.root_labeled = root[0]
        self.mode = mode
        self.use_bigt = use_bigt

        self.imgs_labeled = make_dataset(self.root_labeled, prefix=prefix)
        self.split_rate = split_rate
        # Unlabeled samples per labeled sample (floored, e.g. (12, 36) -> 3).
        self.r_l_rate = split_rate[1] // split_rate[0]
        if self.r_l_rate < 1:
            raise ValueError(
                f"split_rate must contain at least as many unlabeled as labeled samples, got {split_rate}"
            )
        len_labeled = len(self.imgs_labeled)

        self.root_unlabeled = root[1]
        self.imgs_unlabeled = make_dataset(self.root_unlabeled, prefix=prefix)
        num_unlabeled_files = len(self.imgs_unlabeled)
        if len_labeled == 0 or num_unlabeled_files == 0:
            raise RuntimeError(
                f"Empty training set: found {len_labeled} labeled and {num_unlabeled_files} unlabeled images"
            )

        # Exactly r_l_rate unlabeled entries per labeled one. The repeat count must come
        # from the real number of unlabeled files, otherwise a small unlabeled set is not
        # repeated enough and indexing runs past the end of the list.
        len_unlabeled = self.r_l_rate * len_labeled
        num_repeats = math.ceil(len_unlabeled / num_unlabeled_files)
        self.imgs_unlabeled = (self.imgs_unlabeled * num_repeats)[:len_unlabeled]

        self.length = len_labeled + len_unlabeled
        print(f"使用擴充比例爲:{len(self.imgs_labeled) / len(self.imgs_unlabeled)}")

        # A single transform pipeline, kept simple on purpose.
        self.train_joint_transform = JointResize(in_size)
        self.train_img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # operates on the tensor
        ])
        # ToTensor converts a PIL.Image or an H x W x D np.ndarray in [0, 255] into a
        # D x H x W torch.Tensor in [0.0, 1.0].
        self.train_gt_transform = transforms.ToTensor()

    def _index_to_paths(self, index):
        """Map a dataset index to ``(img_path, gt_path)`` using the interleaved layout."""
        group, offset = divmod(index, self.r_l_rate + 1)
        if offset == 0:
            return self.imgs_labeled[group]
        # Offsets 1..r_l_rate map to consecutive, non-overlapping unlabeled entries.
        return self.imgs_unlabeled[group * self.r_l_rate + offset - 1]

    def __getitem__(self, index):
        img_path, gt_path = self._index_to_paths(index)

        # `with` closes the file handles; convert() returns a new image object.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        with Image.open(gt_path) as gt_file:
            gt = gt_file.convert('L')
        # basename + splitext keeps names containing dots intact.
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        img, gt = self.train_joint_transform(img, gt)
        img = self.train_img_transform(img)
        gt = self.train_gt_transform(gt)
        if self.use_bigt:
            gt = gt.ge(0.5).float()  # binarize
        return img, gt, img_name  # the name is returned to make comparisons easier

    def __len__(self):
        return self.length
    
print(f" ==>> 使用的訓練集 <<==\n -->> LABELED_PATH:{LABELED_PATH}\n -->> UNLABELED_PATH:{UNLABELED_PATH}")
train_set = ImageFolder((LABELED_PATH, UNLABELED_PATH), "train", 320, prefix=('.jpg', '.png'), use_bigt=True, split_rate=(12, 36))
# The order inside train_set is fixed (one labeled sample, then r_l_rate unlabeled ones);
# keeping that order is what preserves the ratio, so `shuffle=True` cannot be used.
group_size = train_set.r_l_rate + 1
batch_size = 48
# A batch must consist of whole groups, otherwise the ratio and the split below break.
assert batch_size % group_size == 0, "batch_size must be a multiple of r_l_rate + 1"
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=8, shuffle=False, drop_last=True, pin_memory=True)

for train_idx, train_data in enumerate(train_loader):
    train_inputs, train_gts, train_names = train_data
    print(train_names)

    # In real training these lines would be enabled; they are off here for convenience.
    # train_inputs = train_inputs.to(self.dev)
    # train_gts = train_gts.to(self.dev)

    # Labeled samples are interleaved at positions 0, group_size, 2 * group_size, ...
    # rather than stored in the first rows, so split by stride, not with split((12, 36)).
    sample_shape = train_inputs.shape[1:]
    grouped_inputs = train_inputs.view(-1, group_size, *sample_shape)
    train_labeled_inputs = grouped_inputs[:, 0]
    train_unlabeled_inputs = grouped_inputs[:, 1:].reshape(-1, *sample_shape)
    train_labeled_gts = train_gts[::group_size]

    # otr_total = self.net(train_inputs)
    # grouped_otr = otr_total.view(-1, group_size, *otr_total.shape[1:])
    # labeled_otr = grouped_otr[:, 0]
    # unlabeled_otr = grouped_otr[:, 1:].reshape(-1, *otr_total.shape[1:])
    # with torch.no_grad():
    #     ema_unlabeled_otr = ema_model(train_unlabeled_inputs)
    print(" ==>> 一個Batch結束了 <<== ")
    if train_idx == 2:
        break
print(" ==>> 一個Epoch結束了 <<== ")
 ==>> 使用的訓練集 <<==
 -->> LABELED_PATH:['/kaggle/input/pascal-s/Pascal-S/Image', '/kaggle/input/pascal-s/Pascal-S/Mask']
 -->> UNLABELED_PATH:['/kaggle/input/ecssd/ECSSD/Image', '/kaggle/input/ecssd/ECSSD/Mask']
使用擴充比例爲:0.3333333333333333
('513', '0321', '0864', '0692', '39', '0864', '0692', '0854', '747', '0692', '0854', '0821', '150', '0854', '0821', '0410', '364', '0821', '0410', '0041', '653', '0410', '0041', '0728', '199', '0041', '0728', '0133', '428', '0728', '0133', '0961', '146', '0133', '0961', '0990', '281', '0961', '0990', '0756', '129', '0990', '0756', '0099', '758', '0756', '0099', '0938')
 ==>> 一個Batch結束了 <<== 
('552', '0099', '0938', '0988', '373', '0938', '0988', '0085', '665', '0988', '0085', '0337', '445', '0085', '0337', '0531', '584', '0337', '0531', '0545', '366', '0531', '0545', '0254', '565', '0545', '0254', '0883', '165', '0254', '0883', '0878', '343', '0883', '0878', '0514', '221', '0878', '0514', '0572', '475', '0514', '0572', '0626', '470', '0572', '0626', '0827')
 ==>> 一個Batch結束了 <<== 
('666', '0626', '0827', '0688', '527', '0827', '0688', '0696', '838', '0688', '0696', '0192', '223', '0696', '0192', '0483', '557', '0192', '0483', '0910', '86', '0483', '0910', '0544', '673', '0910', '0544', '0183', '742', '0544', '0183', '0179', '71', '0183', '0179', '0458', '323', '0179', '0458', '0551', '735', '0458', '0551', '0952', '824', '0551', '0952', '0554')
 ==>> 一個Batch結束了 <<== 
 ==>> 一個Epoch結束了 <<== 

方法二:直接在__getitem__中一次性讀取最簡化比例數量的樣本

上面的用法雖然簡單,直接在一個ImageFolder中對數據進行組合,但是這樣會導致一個問題:訓練的時候無法使用shuffle=True設定,對於訓練並不完美。此外,batch內有標籤與無標籤樣本是交錯排列的(每r_l_rate+1個樣本中的第一個是有標籤樣本),劃分時必須按間隔取出,不能直接按前後兩段split。

除了這裏的設置方式,還有一種值得參考:在PoolNet的設置中,是直接對於每次迭代按照1:1的比例輸入,所以其在__getitem__中直接同時imread兩個數據集的圖像。雖然這樣比較簡單,但是卻也是直接有效。

下面仿寫一份。

import os

import torch.utils.data as data
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import math


class JointResize(object):
    """Apply one shared resize to an ``(image, mask)`` pair.

    ``size`` may be an int, meaning a square ``(size, size)`` output, or a tuple
    that is forwarded to ``resize`` as-is.
    """

    def __init__(self, size):
        if not isinstance(size, (int, tuple)):
            raise RuntimeError("size參數請設置爲int或者tuple")
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, img, mask):
        resized_img = img.resize(self.size)
        resized_mask = mask.resize(self.size)
        return resized_img, resized_mask

def make_dataset(root, prefix=('jpg', 'png')):
    """Return sorted ``(image_path, mask_path)`` pairs for the images in ``root[0]``.

    Args:
        root: ``(image_dir, mask_dir)``.
        prefix: ``(image_suffix, mask_suffix)``; these are suffixes, and a leading dot
            is added when missing so that the default ``('jpg', 'png')`` builds valid
            paths.

    Returns:
        ``(image_path, mask_path)`` tuples sorted by image name (``os.listdir`` order is
        arbitrary). Mask files are not checked for existence.
    """
    img_dir, gt_dir = root[0], root[1]
    img_suffix, gt_suffix = [s if s.startswith('.') else '.' + s for s in (prefix[0], prefix[1])]
    names = sorted(os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith(img_suffix))
    return [
        (os.path.join(img_dir, name + img_suffix), os.path.join(gt_dir, name + gt_suffix))
        for name in names
    ]


# Training set only
class ImageFolder(data.Dataset):
    """Training dataset whose every item bundles one labeled sample with its unlabeled partners.

    With ``r_l_rate = split_rate[1] // split_rate[0]``, item ``index`` contains labeled
    sample ``index`` plus the unlabeled entries ``index * r_l_rate ... index * r_l_rate
    + r_l_rate - 1``. The ratio is therefore fixed inside each item, and the loader may
    shuffle freely.

    Args:
        root: ``(labeled_root, unlabeled_root)``, each an ``(image_dir, mask_dir)`` pair.
        mode: mode name; only type-checked and stored.
        in_size: output size for ``JointResize``.
        prefix: ``(image_suffix, mask_suffix)`` passed to ``make_dataset``.
        use_bigt: if True, labeled masks are binarized with a 0.5 threshold.
        split_rate: ``(labeled, unlabeled)`` ratio; only the floored quotient is used.

    Items are ``([img, gt, name], [[unlabeled_imgs...], [unlabeled_names...]])``.
    """

    def __init__(self, root, mode, in_size, prefix, use_bigt=False, split_rate=(1, 3)):
        assert isinstance(mode, str), 'mode參數錯誤,應該爲str類型'
        self.mode = mode
        self.use_bigt = use_bigt
        self.split_rate = split_rate
        self.r_l_rate = split_rate[1] // split_rate[0]
        if self.r_l_rate < 1:
            raise ValueError(
                f"split_rate must contain at least as many unlabeled as labeled samples, got {split_rate}"
            )

        self.root_labeled = root[0]
        self.imgs_labeled = make_dataset(self.root_labeled, prefix=prefix)

        len_labeled = len(self.imgs_labeled)
        self.length = len_labeled

        self.root_unlabeled = root[1]
        self.imgs_unlabeled = make_dataset(self.root_unlabeled, prefix=prefix)
        num_unlabeled_files = len(self.imgs_unlabeled)
        if len_labeled == 0 or num_unlabeled_files == 0:
            raise RuntimeError(
                f"Empty training set: found {len_labeled} labeled and {num_unlabeled_files} unlabeled images"
            )

        # r_l_rate unlabeled entries per labeled one; the repeat count is based on the
        # real number of unlabeled files so the list is always long enough.
        len_unlabeled = self.r_l_rate * len_labeled
        num_repeats = math.ceil(len_unlabeled / num_unlabeled_files)
        self.imgs_unlabeled = (self.imgs_unlabeled * num_repeats)[:len_unlabeled]

        print(f"使用比例爲:{len_labeled / len_unlabeled}")

        # A single transform pipeline, kept simple on purpose.
        self.train_joint_transform = JointResize(in_size)
        self.train_img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # operates on the tensor
        ])
        # ToTensor converts a PIL.Image or an H x W x D np.ndarray in [0, 255] into a
        # D x H x W torch.Tensor in [0.0, 1.0].
        self.train_gt_transform = transforms.ToTensor()

    def _unlabeled_indices(self, index):
        """Return the positions in ``imgs_unlabeled`` that belong to labeled item ``index``."""
        start = index * self.r_l_rate
        return range(start, start + self.r_l_rate)

    def __getitem__(self, index):
        # Load one labeled sample plus r_l_rate unlabeled samples; each is transformed separately.
        img_labeled_path, gt_labeled_path = self.imgs_labeled[index]
        with Image.open(img_labeled_path) as img_file:
            img_labeled = img_file.convert('RGB')
        with Image.open(gt_labeled_path) as gt_file:
            gt_labeled = gt_file.convert('L')
        img_labeled_name = os.path.splitext(os.path.basename(img_labeled_path))[0]

        # Placeholder mask that lets unlabeled images go through the joint transform;
        # its resized output is discarded.
        back_gt_labeled = gt_labeled
        img_labeled, gt_labeled = self.train_joint_transform(img_labeled, gt_labeled)
        img_labeled = self.train_img_transform(img_labeled)
        gt_labeled = self.train_gt_transform(gt_labeled)
        if self.use_bigt:
            gt_labeled = gt_labeled.ge(0.5).float()  # binarize
        data_labeled = [img_labeled, gt_labeled, img_labeled_name]

        data_unlabeled = [[], []]
        for unlabeled_index in self._unlabeled_indices(index):
            # The ground truth of unlabeled data is not used.
            img_unlabeled_path, _ = self.imgs_unlabeled[unlabeled_index]
            with Image.open(img_unlabeled_path) as img_file:
                img_unlabeled = img_file.convert('RGB')
            img_unlabeled_name = os.path.splitext(os.path.basename(img_unlabeled_path))[0]

            img_unlabeled, _ = self.train_joint_transform(img_unlabeled, back_gt_labeled)
            img_unlabeled = self.train_img_transform(img_unlabeled)

            data_unlabeled[0].append(img_unlabeled)
            data_unlabeled[1].append(img_unlabeled_name)

        return data_labeled, data_unlabeled  # names are returned to make comparisons easier

    def __len__(self):
        return self.length
    
print(f" ==>> 使用的訓練集 <<==\n -->> LABELED_PATH:{LABELED_PATH}\n -->> UNLABELED_PATH:{UNLABELED_PATH}")
train_set = ImageFolder(
    (LABELED_PATH, UNLABELED_PATH),
    "train",
    320,
    prefix=('.jpg', '.png'),
    use_bigt=True,
    split_rate=(12, 36),
)
# Every item already bundles one labeled sample with its unlabeled partners, so
# shuffling items keeps the labeled:unlabeled ratio of each batch intact.
train_loader = DataLoader(
    train_set,
    batch_size=12,
    num_workers=8,
    shuffle=True,
    drop_last=False,
    pin_memory=True,
)

for train_idx, (data_labeled, data_unlabeled) in enumerate(train_loader):
    train_labeled_inputs, train_labeled_gts, train_labeled_names = data_labeled
    print(train_labeled_inputs.size(), train_labeled_gts.size(), train_labeled_names)

    train_unlabeled_inputs_list, train_unlabeled_names = data_unlabeled
    # One batched tensor per unlabeled slot; concatenate them along the batch dimension.
    train_unlabeled_inputs = torch.cat(train_unlabeled_inputs_list, dim=0)
    print(train_unlabeled_inputs.size(), train_unlabeled_names)

    train_labeled_inputs_batchsize, train_unlabeled_inputs_batchsize = (
        train_labeled_inputs.size(0),
        train_unlabeled_inputs.size(0),
    )

    # Enable these in real training. The two parts are moved to the device separately
    # (instead of cat first, then .to(dev)) so the unlabeled tensor fed to ema_model is
    # already on the GPU and does not need another copy.
    # train_labeled_inputs = train_labeled_inputs.to(dev)
    # train_unlabeled_inputs = train_unlabeled_inputs.to(dev)
    # train_gts = train_labeled_gts.to(self.dev)
    train_inputs = torch.cat([train_labeled_inputs, train_unlabeled_inputs], dim=0)

    # otr_total = net(train_inputs)
    # labeled_otr, unlabeled_otr = otr_total.split((train_labeled_inputs_batchsize, train_unlabeled_inputs_batchsize), dim=0)
    # with torch.no_grad():
    #     ema_unlabeled_otr = ema_model(train_unlabeled_inputs)
    print(" ==>> 一個Batch結束了 <<== ")
    if train_idx == 2:
        break
print(" ==>> 一個Epoch結束了 <<== ")
 ==>> 使用的訓練集 <<==
 -->> LABELED_PATH:['/kaggle/input/pascal-s/Pascal-S/Image', '/kaggle/input/pascal-s/Pascal-S/Mask']
 -->> UNLABELED_PATH:['/kaggle/input/ecssd/ECSSD/Image', '/kaggle/input/ecssd/ECSSD/Mask']
使用比例爲:0.3333333333333333
torch.Size([12, 3, 320, 320]) torch.Size([12, 1, 320, 320]) ('299', '566', '138', '678', '700', '457', '266', '310', '810', '743', '469', '592')
torch.Size([36, 3, 320, 320]) [('0387', '0094', '0578', '0462', '0399', '0377', '0807', '0970', '0287', '0591', '0514', '0500'), ('0508', '0069', '0818', '0314', '0068', '0453', '0850', '0749', '0469', '0252', '0572', '0914'), ('0847', '0232', '0609', '0716', '0287', '0457', '0294', '0225', '0591', '0538', '0626', '0931')]
 ==>> 一個Batch結束了 <<== 
torch.Size([12, 3, 320, 320]) torch.Size([12, 1, 320, 320]) ('26', '771', '37', '814', '248', '389', '848', '3', '66', '153', '448', '227')
torch.Size([36, 3, 320, 320]) [('0322', '0464', '0972', '0734', '0043', '0800', '0483', '0807', '0029', '0425', '0976', '0741'), ('0054', '0527', '0683', '0694', '0612', '0390', '0910', '0850', '0548', '0260', '0335', '0406'), ('0761', '0586', '0936', '0501', '0073', '0381', '0544', '0294', '0007', '0633', '0505', '0322')]
 ==>> 一個Batch結束了 <<== 
torch.Size([12, 3, 320, 320]) torch.Size([12, 1, 320, 320]) ('805', '635', '739', '56', '80', '78', '496', '575', '359', '379', '55', '354')
torch.Size([36, 3, 320, 320]) [('0032', '0164', '0314', '0407', '0165', '0734', '0540', '0501', '0137', '0058', '0740', '0053'), ('0470', '0464', '0716', '0740', '0413', '0694', '0671', '0834', '0707', '0387', '0186', '0876'), ('0053', '0527', '0601', '0186', '0800', '0501', '0218', '0524', '0679', '0508', '0588', '0578')]
 ==>> 一個Batch結束了 <<== 
 ==>> 一個Epoch結束了 <<== 

補充

上面的操作中,也可以考慮將img_unlabeled與img_labeled直接按照比例放到一起,而真值部分僅返回gt_labeled,同時將img_unlabeled_name與img_labeled_name一起返回,下面是例子:

    def __getitem__(self, index):
        """Return one labeled sample and its ``r_l_rate`` unlabeled partners packed together.

        Returns:
            total_img: ``[labeled_img, unlabeled_img_0, ..., unlabeled_img_{r_l_rate - 1}]``.
            labeled_gt: ``[labeled_gt]``; only the labeled image has a ground truth.
            total_name: image names in the same order as ``total_img``.
        """
        total_img, labeled_gt, total_name = [], [], []

        img_labeled_path, gt_labeled_path = self.imgs_labeled[index]
        with Image.open(img_labeled_path) as img_file:
            img_labeled = img_file.convert('RGB')
        with Image.open(gt_labeled_path) as gt_file:
            gt_labeled = gt_file.convert('L')
        img_labeled_name = os.path.splitext(os.path.basename(img_labeled_path))[0]

        # Placeholder mask so unlabeled images can go through the joint transform;
        # its resized output is discarded.
        back_gt_labeled = gt_labeled
        img_labeled, gt_labeled = self.train_joint_transform(img_labeled, gt_labeled)
        img_labeled = self.train_img_transform(img_labeled)
        gt_labeled = self.train_gt_transform(gt_labeled)
        if self.use_bigt:
            gt_labeled = gt_labeled.ge(0.5).float()  # binarize
        total_img.append(img_labeled)
        labeled_gt.append(gt_labeled)
        total_name.append(img_labeled_name)

        # Labeled sample `index` owns unlabeled entries index * r_l_rate ...
        # index * r_l_rate + r_l_rate - 1, so different indices never share unlabeled images.
        first_unlabeled = index * self.r_l_rate
        for idx_periter in range(self.r_l_rate):
            # The ground truth of unlabeled data is not used.
            img_unlabeled_path, _ = self.imgs_unlabeled[first_unlabeled + idx_periter]
            with Image.open(img_unlabeled_path) as img_file:
                img_unlabeled = img_file.convert('RGB')
            img_unlabeled_name = os.path.splitext(os.path.basename(img_unlabeled_path))[0]

            img_unlabeled, _ = self.train_joint_transform(img_unlabeled, back_gt_labeled)
            img_unlabeled = self.train_img_transform(img_unlabeled)

            total_img.append(img_unlabeled)
            total_name.append(img_unlabeled_name)

        return total_img, labeled_gt, total_name  # names are returned to make comparisons easier

這樣在返回之後只需要對數據進行分割後處理即可,但是這裏的分割需要按照間隔分割,並不方便。

方法三:改造DataLoader

這一點主要受到了mean-teacher的啓發。

class TwoStreamBatchSampler(Sampler):
    """Yield batches that concatenate a primary and a secondary index stream.

    Each batch is ``batch_size - secondary_batch_size`` primary indices followed by
    ``secondary_batch_size`` secondary indices. One epoch is a single pass over the
    primary indices (``iterate_once``), while the secondary indices are cycled
    (``iterate_eternally``) as often as needed. ``iterate_once``,
    ``iterate_eternally`` and ``grouper`` are helpers defined outside this excerpt.
    """

    def __init__(self, primary_indices, secondary_indices, batch_size, secondary_batch_size):
        self.primary_indices = primary_indices
        self.secondary_indices = secondary_indices
        self.secondary_batch_size = secondary_batch_size
        self.primary_batch_size = batch_size - secondary_batch_size

        # Both streams need a positive share of each batch and enough indices for one batch.
        assert len(self.primary_indices) >= self.primary_batch_size > 0
        assert len(self.secondary_indices) >= self.secondary_batch_size > 0

    def __iter__(self):
        primary_stream = iterate_once(self.primary_indices)
        secondary_stream = iterate_eternally(self.secondary_indices)
        paired_batches = zip(
            grouper(primary_stream, self.primary_batch_size),
            grouper(secondary_stream, self.secondary_batch_size),
        )
        return (first + second for first, second in paired_batches)

    def __len__(self):
        # Number of complete primary-sized groups in one pass over the primary indices.
        full_batches, _ = divmod(len(self.primary_indices), self.primary_batch_size)
        return full_batches

調用的時候:

    dataset = torchvision.datasets.ImageFolder(traindir, train_transformation)

    # Both sampling strategies below need the labeled/unlabeled index split, which is only
    # produced from the labels file. Fail early instead of with a NameError further down.
    if not args.labels:
        raise ValueError("args.labels is required to split labeled and unlabeled indices")
    with open(args.labels) as f:
        labels = dict(line.split(' ') for line in f.read().splitlines())
    labeled_idxs, unlabeled_idxs = data.relabel_dataset(dataset, labels)

    if args.exclude_unlabeled:
        # Labeled-only training: ordinary random batches over the labeled subset.
        sampler = SubsetRandomSampler(labeled_idxs)
        batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=True)
    elif args.labeled_batch_size:
        # Every batch: (batch_size - labeled_batch_size) unlabeled indices followed by
        # labeled_batch_size labeled ones; one epoch is one pass over the unlabeled indices.
        batch_sampler = data.TwoStreamBatchSampler(
            unlabeled_idxs, labeled_idxs, args.batch_size, args.labeled_batch_size)
    else:
        # A configuration error, not an internal invariant: raise instead of `assert False`.
        raise ValueError("labeled batch size {}".format(args.labeled_batch_size))

    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_sampler=batch_sampler,
                                               num_workers=args.workers,
                                               pin_memory=True)

這一部分需要分析下DataLoader的幾個參數。

參考資料:

  • https://blog.csdn.net/u014380165/article/details/79058479
  • pytorch學習筆記(十四): DataLoader源碼閱讀 https://blog.csdn.net/u012436149/article/details/78545766
  • Pytorch中的數據加載藝術 http://studyai.com/article/11efc2bf
class DataLoader(object):
    r"""
    Data loader. Combines a dataset and a sampler, and provides
    single- or multi-process iterators over the dataset.

    The "Note:" paragraphs under each argument are the article author's
    annotations (translated to English); the remaining text is the quoted
    upstream docstring.

    Arguments:
        dataset (Dataset): dataset from which to load the data.
            Note: an instance of your own Dataset subclass, which implements the
            basic loading pipeline, e.g. collecting the list of file paths,
            opening an image for a given index, preprocessing it, and so on.
        batch_size (int, optional): how many samples per batch to load
            (default: ``1``).
            Note: as the name says, this fixes the batch size; a batch is a
            bundle of several samples.
        shuffle (bool, optional): set to ``True`` to have the data reshuffled
            at every epoch (default: ``False``).
            Note: whether the original order is reshuffled every epoch; usually
            True for training and False for testing.
        sampler (Sampler, optional): defines the strategy to draw samples
            from the dataset. If specified, ``shuffle`` must be False.
            Note: the sampling strategy, returning one sample index at a time; it
            is a subclass of Sampler. ``shuffle`` must be off in this case, i.e.
            any shuffling has to be implemented by the sampler itself.
        batch_sampler (Sampler, optional): like sampler, but returns a batch of
            indices at a time. Mutually exclusive with :attr:`batch_size`,
            :attr:`shuffle`, :attr:`sampler`, and :attr:`drop_last`.
            Note: similar to sampler but one step further: it defines the
            sampling strategy at the batch level and returns the indices of a
            whole batch at a time; mutually exclusive with
            batch_size/shuffle/sampler/drop_last.
        num_workers (int, optional): how many subprocesses to use for data
            loading. 0 means that the data will be loaded in the main process.
            (default: ``0``)
            Note: number of worker subprocesses used for reading data; can speed
            up loading to some extent.
        collate_fn (callable, optional): merges a list of samples to form a mini-batch.
            Note: a callable that merges samples to build the mini-batch.
        pin_memory (bool, optional): If ``True``, the data loader will copy tensors
            into CUDA pinned memory before returning them.  If your data elements
            are a custom type, or your ``collate_fn`` returns a batch that is a custom type
            see the example below.
            Note: if True, the loader copies tensors into CUDA pinned memory
            before returning them.
        drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
            if the dataset size is not divisible by the batch size. If ``False`` and
            the size of dataset is not divisible by the batch size, then the last batch
            will be smaller. (default: ``False``)
            Note: whether to drop the last incomplete batch of each epoch, if
            there is one.
        timeout (numeric, optional): if positive, the timeout value for collecting a batch
            from workers. Should always be non-negative. (default: ``0``)
            Note: a non-negative timeout for collecting data from the workers; an
            error is raised if no data has been read within that time.
        worker_init_fn (callable, optional): If not ``None``, this will be called on each
            worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
            input, after seeding and before data loading. (default: ``None``)
            Note: if not None, called in every worker subprocess with the worker
            id as input, after seeding and before data loading. As the note below
            explains, a typical use is seeding other libraries (e.g. NumPy) from
            ``torch.initial_seed()`` so that workers do not return identical
            random numbers.

    .. note:: When ``num_workers != 0``, the corresponding worker processes are created each time
              iterator for the DataLoader is obtained (as in when you call
              ``enumerate(dataloader,0)``).
              At this point, the dataset, ``collate_fn`` and ``worker_init_fn`` are passed to each
              worker, where they are used to access and initialize data based on the indices
              queued up from the main process. This means that dataset access together with
              its internal IO, transforms and collation runs in the worker, while any
              shuffle randomization is done in the main process which guides loading by assigning
              indices to load. Workers are shut down once the end of the iteration is reached.

              Since workers rely on Python multiprocessing, worker launch behavior is different
              on Windows compared to Unix. On Unix fork() is used as the default
              multiprocessing start method, so child workers typically can access the dataset and
              Python argument functions directly through the cloned address space. On Windows, another
              interpreter is launched which runs your main script, followed by the internal
              worker function that receives the dataset, collate_fn and other arguments
              through Pickle serialization.

              This separate serialization means that you should take two steps to ensure you
              are compatible with Windows while using workers
              (this also works equally well on Unix):

              - Wrap most of your main script's code within ``if __name__ == '__main__':`` block,
                to make sure it doesn't run again (most likely generating error) when each worker
                process is launched. You can place your dataset and DataLoader instance creation
                logic here, as it doesn't need to be re-executed in workers.
              - Make sure that ``collate_fn``, ``worker_init_fn`` or any custom dataset code
                is declared as a top level def, outside of that ``__main__`` check. This ensures
                they are available in workers as well
                (this is needed since functions are pickled as references only, not bytecode).

              By default, each worker will have its PyTorch seed set to
              ``base_seed + worker_id``, where ``base_seed`` is a long generated
              by main process using its RNG. However, seeds for other libraries
              may be duplicated upon initializing workers (e.g., NumPy), causing
              each worker to return identical random numbers. (See
              :ref:`dataloader-workers-random-seed` section in FAQ.) You may
              use :func:`torch.initial_seed()` to access the PyTorch seed for
              each worker in :attr:`worker_init_fn`, and use it to set other
              seeds before data loading.

    .. warning:: If ``spawn`` start method is used, :attr:`worker_init_fn` cannot be an
                 unpicklable object, e.g., a lambda function.

    The default memory pinning logic only recognizes Tensors and maps and iterables
    containing Tensors.  By default, if the pinning logic sees a batch that is a custom type
    (which will occur if you have a ``collate_fn`` that returns a custom batch type),
    or if each element of your batch is a custom type, the pinning logic will not
    recognize them, and it will return that batch (or those elements)
    without pinning the memory.  To enable memory pinning for custom batch or data types,
    define a ``pin_memory`` method on your custom type(s).
    Note (summary of the paragraph above): only tensors and maps/iterables of
    tensors are pinned automatically; a custom batch or element type is
    returned unpinned unless it defines its own ``pin_memory`` method, as in the
    example below.

    Example::

        class SimpleCustomBatch:
            def __init__(self, data):
                transposed_data = list(zip(*data))
                self.inp = torch.stack(transposed_data[0], 0)
                self.tgt = torch.stack(transposed_data[1], 0)

            def pin_memory(self):
                self.inp = self.inp.pin_memory()
                self.tgt = self.tgt.pin_memory()
                return self

        def collate_wrapper(batch):
            return SimpleCustomBatch(batch)

        inps = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        tgts = torch.arange(10 * 5, dtype=torch.float32).view(10, 5)
        dataset = TensorDataset(inps, tgts)

        loader = DataLoader(dataset, batch_size=2, collate_fn=collate_wrapper,
                            pin_memory=True)

        for batch_ndx, sample in enumerate(loader):
            print(sample.inp.is_pinned())
            print(sample.tgt.is_pinned())

    """

這裏主要關注參數中的sampler、batch_sampler以及collate_fn的用法。

sampler與batch_sampler

首先可以看默認要求是如何:

        # When batch_sampler is given, batch_size must be 1, shuffle False, sampler None and drop_last False.
        # In other words, batch_sampler itself must form the batches, shuffle the data and handle the last batch.
        if batch_sampler is not None:
            if batch_size > 1 or shuffle or sampler is not None or drop_last:
                raise ValueError('batch_sampler option is mutually exclusive '
                                 'with batch_size, shuffle, sampler, and '
                                 'drop_last')
            self.batch_size = None
            self.drop_last = None

        # When sampler is given, shuffle must be False: the sampler is responsible for producing
        # (and, if desired, shuffling) the indices.
        if sampler is not None and shuffle:
            raise ValueError('sampler option is mutually exclusive with '
                             'shuffle')

        if self.num_workers < 0:
            raise ValueError('num_workers option cannot be negative; '
                             'use num_workers=0 to disable multiprocessing.')

        # When neither batch_sampler nor sampler is given, sampler defaults to RandomSampler
        # (shuffle=True, random order) or SequentialSampler (shuffle=False, fixed order), and
        # batch_sampler defaults to BatchSampler wrapping it. So to write a custom batch_sampler
        # or sampler, these three classes are the templates to imitate.
        if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self.__initialized = True

簡言之,採樣器定義了索引(index)的產生規則,按指定規則去產生索引,從而控制數據的讀取機制(http://studyai.com/article/11efc2bf)

查看這幾個類,這裏的代碼來自V1.1.0

import torch
from torch._six import int_classes as _int_classes


class Sampler(object):
    r"""Base class for all Samplers.
    Every Sampler subclass has to provide an __iter__ method, providing a way
    to iterate over indices of dataset elements, and a __len__ method that
    returns the length of the returned iterators.

    Note: every Sampler subclass (the sampling classes below) has to implement
    the methods defined here.
    """

    def __init__(self, data_source):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        # From v1.2.0 on, __len__ is no longer required: https://github.com/pytorch/pytorch/blob/v1.2.0/torch/utils/data/sampler.py#L23-L48
        raise NotImplementedError


class SequentialSampler(Sampler):
    r"""Samples elements sequentially, always in the same order.
    Arguments:
        data_source (Dataset): dataset to sample from

    Note: the order is the same in every epoch, which is why
    range(len(self.data_source)) is used directly as the index order.
    """

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)


class RandomSampler(Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify ``num_samples`` to draw.
    Arguments:
        data_source (Dataset): dataset to sample from
        replacement (bool): samples are drawn with replacement if ``True``, default=``False``
        num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
            is supposed to be specified only when `replacement` is ``True``.

    Note: returns an iterator over randomly shuffled indices (or, with
    replacement=True, over randomly drawn indices that may repeat).
    """

    def __init__(self, data_source, replacement=False, num_samples=None):
        self.data_source = data_source
        self.replacement = replacement
        self._num_samples = num_samples

        if not isinstance(self.replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(self.replacement))

        if self._num_samples is not None and not replacement:
            raise ValueError("With replacement=False, num_samples should not be specified, "
                             "since a random permute will be performed.")

        if not isinstance(self.num_samples, int) or self.num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(self.num_samples))

    # About this decorator: https://www.programiz.com/python-programming/property
    # It exposes the private attribute _num_samples through a read-only interface,
    # falling back to len(self.data_source) when it was not given.
    @property
    def num_samples(self):
        # dataset size might change at runtime
        if self._num_samples is None:
            return len(self.data_source)
        return self._num_samples

    def __iter__(self):
        n = len(self.data_source)
        if self.replacement:
            # With replacement the number of drawn indices can be chosen (= self.num_samples),
            # while the index range [0, len(self.data_source)) is fixed.
            return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
        # torch.randperm(n) Returns a random permutation of integers from 0 to n - 1.
        # https://pytorch.org/docs/1.1.0/torch.html#torch.randperm
        return iter(torch.randperm(n).tolist())

    def __len__(self):
        return self.num_samples


class SubsetRandomSampler(Sampler):
    r"""Samples elements randomly from a given list of indices, without replacement.
    Arguments:
        indices (sequence): a sequence of indices

    Note: samples only from the given subset of the dataset's indices, in a new
    random order each time it is iterated.
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in torch.randperm(len(self.indices)))

    def __len__(self):
        return len(self.indices)


class WeightedRandomSampler(Sampler):
    r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement, so an
            index may be drawn more than once. If not, they are drawn without
            replacement, which means that when a sample index is drawn for a
            row, it cannot be drawn again for that row.

    Raises:
        ValueError: if ``num_samples`` is not a positive int (bools are rejected)
            or ``replacement`` is not a bool.

    Example (draws are random; the outputs show one possible result):
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))  # doctest: +SKIP
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))  # doctest: +SKIP
        [0, 1, 4, 3, 2]

    Indices are sampled in proportion to their weights; the resulting index
    list is what the iterator yields.
    """

    def __init__(self, weights, num_samples, replacement=True):
        # Plain ``int`` instead of the ``_int_classes`` compatibility alias
        # (``torch._six.int_classes`` in PyTorch 1.1, equal to ``int`` on
        # Python 3), which recent PyTorch releases no longer provide.
        # bool is a subclass of int, hence the explicit exclusion.
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        # torch.multinomial samples indices from the (unnormalised) weights:
        # https://pytorch.org/docs/1.1.0/torch.html#torch.multinomial
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        return self.num_samples

 
class BatchSampler(Sampler):
    r"""Wraps another sampler to yield a mini-batch of indices.

    Args:
        sampler (Sampler): Base sampler.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``

    Raises:
        ValueError: if ``sampler`` is not a ``Sampler``, ``batch_size`` is not a
            positive int (bools are rejected), or ``drop_last`` is not a bool.

    Example:
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

    A BatchSampler is a Sampler plus a batch size: it groups the indices
    produced by ``sampler`` into lists of ``batch_size`` indices.
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # Plain ``int`` instead of the ``_int_classes`` compatibility alias
        # (``torch._six.int_classes`` in PyTorch 1.1, equal to ``int`` on
        # Python 3), which recent PyTorch releases no longer provide.
        # bool is a subclass of int, hence the explicit exclusion.
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
                batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        # Accumulate indices and yield a batch every time it is full.
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # A trailing, incomplete batch is emitted unless drop_last is set.
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        """Number of batches: floor(len / batch_size) with drop_last, else the ceiling."""
        full_batches, remainder = divmod(len(self.sampler), self.batch_size)
        if self.drop_last or remainder == 0:
            return full_batches
        # The leftover indices form one extra, smaller batch.
        return full_batches + 1

由上可見,Sampler本質上就是一個按特定規則產生索引的可迭代對象:普通的Sampler每次迭代只產出單個索引,BatchSampler則在此基礎上每次產出一個batch的索引列表。

像 [x for x in range(10)] 或 range(10) 這樣的對象就可以看作最基本的Sampler,每次循環只取出其中的一個值。

# A plain list of indices already behaves like the most basic sampler:
# iterating over it yields one index at a time.
sampler = [x for x in range(10)]
print(f"原始Sampler:{sampler}")

from torch.utils.data.sampler import SequentialSampler
# SequentialSampler yields the indices 0 .. len(sampler) - 1 in order.
print(f"順序採樣:{[x for x in SequentialSampler(sampler)]}")

from torch.utils.data.sampler import RandomSampler
# NOTE: with replacement=True and num_samples=5 this draws 5 indices uniformly
# *with* replacement (duplicates are possible) instead of shuffling all 10
# indices; a true shuffle is RandomSampler(sampler) with the defaults.
print(f"隨機置亂:{[x for x in RandomSampler(data_source=sampler, replacement=True, num_samples=5)]}")
原始Sampler:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
順序採樣:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
隨機置亂:[9, 4, 9, 6, 4]

collate_fn

參考資料:

  • https://jdhao.github.io/2017/10/23/pytorch-load-data-and-make-batch/#loading-variable-size-input-images
  • https://www.cnblogs.com/king-lps/p/10990304.html

查看源代碼(https://github.com/pytorch/pytorch/blob/v1.1.0/torch/utils/data/_utils/collate.py#L31):

def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size

    Only the type of the first sample, ``batch[0]``, is inspected; all other
    samples are assumed to have the same structure. Tensors (and numpy arrays,
    after conversion) are stacked along a new dimension 0, Python floats and
    ints become 1-D tensors, strings are returned unchanged as the input list,
    and mappings, namedtuples and other sequences are collated field by field
    through recursive calls.

    For a dataset returning (image, label) samples, both fields must therefore
    be stackable. In segmentation the label is itself an image, so the masks
    in a batch must share a size just like the images do.

    The default `collate_fn` expects all the images in a batch to have the same size 
    because it uses `torch.stack()` to pack the images. If the images provided by 
    Dataset have variable size, you have to provide your custom `collate_fn`.
    """
    # NOTE(review): _use_shared_memory, np_str_obj_array_pattern, error_msg_fmt,
    # numpy_type_map, int_classes, string_classes and container_abcs are not
    # defined in this excerpt; presumably they are the module-level names of the
    # linked torch/utils/data/_utils/collate.py (PyTorch 1.1).

    elem_type = type(batch[0])
    if isinstance(batch[0], torch.Tensor):
        out = None
        if _use_shared_memory:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = batch[0].storage()._new_shared(numel)
            out = batch[0].new(storage)
        # Stack the per-sample tensors along a new leading batch dimension.
        return torch.stack(batch, 0, out=out)

    # numpy objects, excluding numpy string scalars (str_ / string_).
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        elem = batch[0]
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(error_msg_fmt.format(elem.dtype))

            # Convert each array to a tensor and recurse into the Tensor branch.
            return default_collate([torch.from_numpy(b) for b in batch])
        if elem.shape == ():  # scalars
            # Convert to Python float/int, then build a tensor via the dtype-name lookup table.
            py_type = float if elem.dtype.name.startswith('float') else int
            return numpy_type_map[elem.dtype.name](list(map(py_type, batch)))
        # Any other numpy object falls through to the TypeError at the end.

    elif isinstance(batch[0], float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(batch[0], int_classes):
        return torch.tensor(batch)
    # Strings must be handled before the generic Sequence branch below,
    # because a str is itself a Sequence.
    elif isinstance(batch[0], string_classes):
        return batch

    # Mapping: collate each key across samples, keeping the keys of batch[0].
    elif isinstance(batch[0], container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in batch[0]}
    # namedtuple must be checked before Sequence so its type is preserved.
    elif isinstance(batch[0], tuple) and hasattr(batch[0], '_fields'):  # namedtuple
        return type(batch[0])(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(batch[0], container_abcs.Sequence):
        # Transpose [(a1, b1), (a2, b2), ...] into [(a1, a2, ...), (b1, b2, ...)]
        # and collate each field; the result is a list, not a tuple.
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError((error_msg_fmt.format(type(batch[0]))))

這裏根據輸入數據(以batch[0]爲準)的類型,對不同類別的數據分別進行整理與返回。可以看到,其中有幾處(ndarray、字典、namedtuple和一般序列)通過遞歸調用該函數本身來處理嵌套的數據結構。

import os

import torch.utils.data as data
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import math


class JointResize(object):
    """Resize an image and its mask to the same fixed size.

    Args:
        size: an int ``s`` (meaning ``(s, s)``) or a 2-element tuple/list that
            is passed to ``resize`` of both inputs. For PIL images this is
            ``(width, height)``. Lists, e.g. read from a config file, are
            converted to tuples.

    Raises:
        RuntimeError: if ``size`` is neither an int nor a tuple/list.
    """

    def __init__(self, size):
        if isinstance(size, int):
            self.size = (size, size)
        elif isinstance(size, (tuple, list)):
            # Normalise to an immutable tuple regardless of the container passed in.
            self.size = tuple(size)
        else:
            raise RuntimeError("size參數請設置爲int、tuple或者list")

    def __call__(self, img, mask):
        """Return ``(img, mask)``, each resized to ``self.size``."""
        img = img.resize(self.size)
        mask = mask.resize(self.size)
        return img, mask

def make_dataset(root, prefix=('jpg', 'png')):
    """Pair every image in ``root[0]`` with the same-named mask in ``root[1]``.

    Args:
        root: ``(image_dir, mask_dir)``.
        prefix: ``(image_ext, mask_ext)``. A leading dot is added when missing,
            so ``('jpg', 'png')`` (the default) and ``('.jpg', '.png')`` are
            equivalent. Without this, the dot-less default built paths such
            as ``namejpg``.

    Returns:
        A list of ``(image_path, mask_path)`` tuples, sorted by image name so
        the order does not depend on ``os.listdir``. Mask files are not checked
        for existence.
    """
    img_dir, gt_dir = root[0], root[1]
    img_ext, gt_ext = (ext if ext.startswith('.') else '.' + ext for ext in prefix[:2])
    names = sorted(os.path.splitext(f)[0] for f in os.listdir(img_dir) if f.endswith(img_ext))
    return [(os.path.join(img_dir, name + img_ext), os.path.join(gt_dir, name + gt_ext))
            for name in names]


# Training set only.
class ImageFolder(data.Dataset):
    """Semi-supervised training set: one labeled sample plus ``r_l_rate`` unlabeled ones per item.

    ``__getitem__(index)`` returns ``(data_labeled, data_unlabeled)``:

    * ``data_labeled = [img, gt, name]``: the ``index``-th labeled image, its
      mask (binarised at 0.5 when ``use_bigt``) and its file stem;
    * ``data_unlabeled = [[img, ...], [name, ...]]``: ``r_l_rate`` unlabeled
      images and their file stems. Unlabeled masks are never read.

    Labeled item ``index`` owns the unlabeled entries
    ``[index * r_l_rate, (index + 1) * r_l_rate)``, so one epoch visits each of
    the ``r_l_rate * len(self)`` unlabeled entries exactly once.

    Args:
        root: ``(labeled_root, unlabeled_root)``, each an ``(image_dir, mask_dir)``
            pair for ``make_dataset``.
        mode: string tag stored as ``self.mode`` (not used by this class).
        in_size: target size for ``JointResize``.
        prefix: ``(image_ext, mask_ext)`` for ``make_dataset``.
        use_bigt: binarise the labeled masks with threshold 0.5.
        split_rate: ``(labeled, unlabeled)`` ratio; ``r_l_rate`` is
            ``split_rate[1] // split_rate[0]``.

    Raises:
        ValueError: if ``r_l_rate`` < 1 or either image list is empty.
    """

    def __init__(self, root, mode, in_size, prefix, use_bigt=False, split_rate=(1, 3)):
        """split_rate = label:unlabel"""
        assert isinstance(mode, str), 'mode參數錯誤,應該爲str類型'
        self.mode = mode
        self.use_bigt = use_bigt
        self.split_rate = split_rate
        # Number of unlabeled samples paired with every labeled sample.
        self.r_l_rate = split_rate[1] // split_rate[0]
        if self.r_l_rate < 1:
            raise ValueError(
                f"split_rate={split_rate} must give at least one unlabeled sample "
                f"per labeled sample (split_rate[1] >= split_rate[0])")

        self.root_labeled = root[0]
        self.imgs_labeled = make_dataset(self.root_labeled, prefix=prefix)

        len_labeled = len(self.imgs_labeled)
        if len_labeled == 0:
            raise ValueError(f"no labeled images found under {self.root_labeled}")
        # One dataset item per labeled sample.
        self.length = len_labeled

        self.root_unlabeled = root[1]
        self.imgs_unlabeled = make_dataset(self.root_unlabeled, prefix=prefix)
        if not self.imgs_unlabeled:
            raise ValueError(f"no unlabeled images found under {self.root_unlabeled}")

        # Exactly r_l_rate unlabeled entries are needed per labeled sample.
        len_unlabeled = self.r_l_rate * len_labeled
        # Repeat the unlabeled list just enough times to cover len_unlabeled
        # entries, then truncate. A fixed multiplier is not enough when the
        # unlabeled set is much smaller than the labeled one.
        repeats = math.ceil(len_unlabeled / len(self.imgs_unlabeled))
        self.imgs_unlabeled = (self.imgs_unlabeled * repeats)[:len_unlabeled]

        print(f"使用比例爲:{len_labeled / len_unlabeled}")

        # Only a single (resize) joint transform, for simplicity.
        self.train_joint_transform = JointResize(in_size)
        self.train_img_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # operates on tensors
        ])
        # ToTensor turns a PIL.Image (or an H x W x C uint8 ndarray in [0, 255])
        # into a C x H x W float tensor in [0.0, 1.0].
        self.train_gt_transform = transforms.ToTensor()

    @staticmethod
    def _open(path, mode):
        """Open ``path`` with PIL, convert it to ``mode`` and close the file."""
        with Image.open(path) as img:
            return img.convert(mode)

    @staticmethod
    def _stem(path):
        """File name of ``path`` without its directory and final extension."""
        return os.path.splitext(os.path.basename(path))[0]

    def __getitem__(self, index):
        """Load labeled sample ``index`` together with its ``r_l_rate`` unlabeled partners."""
        img_labeled_path, gt_labeled_path = self.imgs_labeled[index]
        img_labeled = self._open(img_labeled_path, 'RGB')
        gt_labeled = self._open(gt_labeled_path, 'L')
        # JointResize needs an (image, mask) pair; the labeled mask stands in
        # when resizing unlabeled images, and that resized copy is discarded.
        placeholder_gt = gt_labeled
        img_labeled, gt_labeled = self.train_joint_transform(img_labeled, gt_labeled)
        img_labeled = self.train_img_transform(img_labeled)
        gt_labeled = self.train_gt_transform(gt_labeled)
        if self.use_bigt:
            gt_labeled = gt_labeled.ge(0.5).float()  # binarise
        data_labeled = [img_labeled, gt_labeled, self._stem(img_labeled_path)]

        data_unlabeled = [[], []]
        # Contiguous, non-overlapping slice of the unlabeled list for this index.
        start = index * self.r_l_rate
        for offset in range(self.r_l_rate):
            # The unlabeled mask path is ignored.
            img_unlabeled_path, _ = self.imgs_unlabeled[start + offset]
            img_unlabeled = self._open(img_unlabeled_path, 'RGB')
            img_unlabeled, _ = self.train_joint_transform(img_unlabeled, placeholder_gt)
            data_unlabeled[0].append(self.train_img_transform(img_unlabeled))
            data_unlabeled[1].append(self._stem(img_unlabeled_path))

        return data_labeled, data_unlabeled  # names are returned for easier inspection

    def __len__(self):
        return self.length
    
    
def my_collate(batch):
    """Merge ``(data_labeled, data_unlabeled)`` samples into one semi-supervised batch.

    ``batch`` holds one entry per sample (``batch_size`` entries in total), each
    unpackable as ``(data_labeled, data_unlabeled)``, where::

        data_labeled   = [image, gt, name]
        data_unlabeled = [[image, ...], [name, ...]]

    Returns:
        ``([labeled_images, unlabeled_images], [labeled_gts],
        [labeled_names, unlabeled_names])``. The three tensors are built with
        ``torch.stack(..., 0)``; the unlabeled images of all samples are
        concatenated in sample order before stacking. Names stay plain lists.

    The size of the stacked unlabeled tensor is printed on every call (debug
    output of this demo).
    """
    labeled_part, unlabeled_part = [], []
    for labeled, unlabeled in batch:
        labeled_part.append(labeled)
        unlabeled_part.append(unlabeled)

    labeled_images = [item[0] for item in labeled_part]
    labeled_gts = [item[1] for item in labeled_part]
    labeled_names = [item[2] for item in labeled_part]
    # Flatten the per-sample unlabeled lists into single lists, keeping order.
    unlabeled_images = [img for item in unlabeled_part for img in item[0]]
    unlabeled_names = [name for item in unlabeled_part for name in item[1]]

    stacked_labeled = torch.stack(labeled_images, 0)
    stacked_unlabeled = torch.stack(unlabeled_images, 0)
    stacked_gts = torch.stack(labeled_gts, 0)
    print(stacked_unlabeled.size())
    return ([stacked_labeled, stacked_unlabeled], [stacked_gts],
            [labeled_names, unlabeled_names])

print(f" ==>> 使用的訓練集 <<==\n -->> LABELED_PATH:{LABELED_PATH}\n -->> UNLABELED_PATH:{UNLABELED_PATH}")
# split_rate=(3, 9) gives r_l_rate = 9 // 3 = 3 unlabeled images per labeled
# image, so with batch_size=3 every batch holds 3 labeled + 9 unlabeled images.
train_set = ImageFolder((LABELED_PATH, UNLABELED_PATH), "train", 320, prefix=('.jpg', '.png'), use_bigt=True, split_rate=(3, 9))
# a simple custom collate function, just to show the idea
# NOTE(review): with num_workers > 0 and the "spawn" start method (the default
# on Windows and macOS), worker processes re-import the main module, so outside
# a notebook this top-level script would need an `if __name__ == "__main__":` guard.
train_loader = DataLoader(train_set, batch_size=3, num_workers=4, collate_fn=my_collate, shuffle=True, drop_last=False, pin_memory=True)
print(" ==>> data_loader構建完畢 <<==")

for train_idx, train_data in enumerate(train_loader):

    # my_collate returns ([labeled, unlabeled], [labeled_gts], [labeled_names, unlabeled_names]).
    train_inputs, train_gts, train_names = train_data
    
    train_labeled_inputs, train_unlabeled_inputs = train_inputs
    train_labeled_gts = train_gts[0]
    train_labeled_names, train_unlabeled_names = train_names
    print("-->>", train_labeled_inputs.size(), train_labeled_gts.size(), train_labeled_names)
    print("-->>", train_unlabeled_inputs.size(), train_unlabeled_names)
    
    # Sizes of the two groups, needed to split a joint network output back apart.
    train_labeled_inputs_batchsize = train_labeled_inputs.size(0)
    train_unlabeled_inputs_batchsize = train_unlabeled_inputs.size(0)
    
    # In real training the device transfers below would be enabled; they are
    # disabled here for convenience. The two groups are moved to the device
    # separately (instead of cat first, then .to(dev)) so that the unlabeled
    # tensor is already on the GPU when it is later fed to ema_model, avoiding
    # a second transfer.
    # train_labeled_inputs = train_labeled_inputs.to(dev)
    # train_unlabeled_inputs = train_unlabeled_inputs.to(dev)
    # train_gts = train_labeled_gts.to(dev)
    # Labeled images first, then unlabeled, along the batch dimension.
    train_inputs = torch.cat([train_labeled_inputs, train_unlabeled_inputs], dim=0)

    # Sketch of the training step that would consume train_inputs:
    # otr_total = net(train_inputs)
    # labeled_otr, unlabeled_otr = otr_total.split((train_labeled_inputs_batchsize, train_unlabeled_inputs_batchsize), dim=0)
    # with torch.no_grad():
    #     ema_unlabeled_otr = ema_model(train_unlabeled_inputs)
    print(" ==>> 一個Batch結束了 <<== ")
    # Demo only: stop after the first batch.
    if train_idx == 0:
        break
print(" ==>> 一個Epoch結束了 <<== ")
 ==>> 使用的訓練集 <<==
 -->> LABELED_PATH:['/kaggle/input/pascal-s/Pascal-S/Image', '/kaggle/input/pascal-s/Pascal-S/Mask']
 -->> UNLABELED_PATH:['/kaggle/input/ecssd/ECSSD/Image', '/kaggle/input/ecssd/ECSSD/Mask']
使用比例爲:0.3333333333333333
 ==>> data_loader構建完畢 <<==
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
torch.Size([9, 3, 320, 320])
-->> torch.Size([3, 3, 320, 320]) torch.Size([3, 1, 320, 320]) ['783', '5', '116']
-->> torch.Size([9, 3, 320, 320]) ['0817', '0128', '0743', '0214', '0763', '0344', '0818', '0609', '0809']
 ==>> 一個Batch結束了 <<== 
 ==>> 一個Epoch結束了 <<== 

More

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章