雙目測距系列(七)monodepth2訓練前數據集準備過程的簡析

前言

上一篇文章說過,monodepth2模型有三種訓練方式。針對我們的雙目場景,準備使用stereo training方法。

monodepth2的訓練入口函數在train.py中,如下圖所示。

總共就2行代碼,第一行代碼(類Trainer的構造函數)主要是來初始化和數據集準備;第二行代碼(Trainer類的成員函數)是真正執行訓練過程。

下文將結合代碼講解數據集準備過程。

數據加載

在Train()構造函數中,首先會對Trainer類成員變量進行初始化。這裏會摘取重點部分進行講解。

1)

        self.num_scales = len(self.opt.scales)
        self.num_input_frames = len(self.opt.frame_ids)

代碼中的opt是對options.py中的參數parse得到的dict。其參數對應值可以通過運行train.py腳本時輸入參數來進行設置,如下所示。如果在運行train.py時沒有顯示指定參數值,那麼該參數就對應使用缺省值。

python train.py --frame_ids 0 --use_stereo

 回到代碼,因爲在運行train.py時沒有輸入scales參數,所以其爲缺省值[0,1,2,3],其含義是在encoder和decoder時進行4級縮小和放大的多尺度,其倍數分別對應爲1, 2, 4, 8。

frame_ids的缺省值爲[0,-1,1],這裏如果採用stereo training的話 要顯示輸入參數:--frame_ids 0,即當前圖片,而不考慮它的時間域上的上一幀和下一幀。

2)

 if self.opt.use_stereo:
            self.opt.frame_ids.append("s")

如果是stereo training,那麼需要顯示添加參數--use_stereo,這樣上面代碼if條件爲true, frame_ids就變成了["0", "s"]

3)接下來就到了數據加載部分

        datasets_dict = {"kitti": datasets.KITTIRAWDataset,
                         "kitti_odom": datasets.KITTIOdomDataset}
        self.dataset = datasets_dict[self.opt.dataset]

KITTI數據集有兩個子類型:KITTIRAW和KITTIOdom,monodepth使用的是前者,本系列四(https://blog.csdn.net/ltshan139/article/details/105794584)有專門對它進行說明。 

        fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")

        train_filenames = readlines(fpath.format("train"))
        val_filenames = readlines(fpath.format("val"))
        img_ext = '.png' if self.opt.png else '.jpg'

 上面第一行代碼來獲取train和valid的文件路徑:fpath。在monodepth2開源項目根目錄下有一個splits的子目錄,然後在它的下面又分了eigen, eigen_full和eigen_zhou等子目錄,最後每個子目錄下才帶有train_files.txt和val_files.txt。其目錄結構如下所示:

根據github上的readme,單目訓練時推薦用的是eigen_zhou,雙目用的是eigen_full。  

最後一行img_ext用來顯示告訴當前訓練和驗證樣本圖片的格式是png還是jpg。

        train_dataset = self.dataset(
            self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
        self.train_loader = DataLoader(
            train_dataset, self.opt.batch_size, True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
        val_dataset = self.dataset(
            self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
            self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
        self.val_loader = DataLoader(
            val_dataset, self.opt.batch_size, True,
            num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)

 上面的代碼就是真正數據加載部分。因爲train和valid數據加載原理一樣,而且DatalLoader是pytorch的API,沒啥好講的,所以這裏主要分析下train_dataset = self.dataset(...)的運行過程。

前面已經講過了 self.dataset=datasets.KITTIRAWDataset。調用self.dataset(...)實際上調用的是datasets.KITTIRAWDataset的構造函數,如下所示。

class KITTIRAWDataset(KITTIDataset):
    """KITTI dataset which loads the original velodyne depth maps for ground truth
    """
    def __init__(self, *args, **kwargs):
        super(KITTIRAWDataset, self).__init__(*args, **kwargs)

其構造函數只有一行代碼: super(KITTIRAWDataset, self).__init__(*args, **kwargs),實際上它會調用其父類KITTIDataset的構造函數,如下所示。

class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)
        。。。 。。。

 裏面的super函數又會調用KITTIDataset的父類MonoDataset的構造函數。

class MonoDataset(data.Dataset):
    """Superclass for monocular dataloaders

    Args:
        data_path
        filenames
        height
        width
        frame_idxs
        num_scales
        is_train
        img_ext
    """
    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        self.interp = Image.ANTIALIAS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()
        。。。 。。。

注意,self.dataset(。。。)所帶的實參全部賦值給了MonoDataset(。。。),比如說data_path, filenames等。相當於把全部訓練和驗證樣本文件名拿到了,以便後面訓練時一個一個batch來從數據集裏面隨機抽取。

MonoDataset的構造函數運行完成後再回到KITTIDataset的構造函數剩餘部分執行。

 

 

 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章