前言
上一篇文章說過,monodepth2模型有三種訓練方式。針對我們的雙目場景,準備使用stereo training方法。
monodepth2的訓練入口函數在train.py中,如下圖所示。
總共就2行代碼,第一行代碼(類Trainer的構造函數)主要是來初始化和數據集準備;第二行代碼(Trainer類的成員函數)是真正執行訓練過程。
下文將結合代碼講解數據集準備過程。
數據加載
在Train()構造函數中,首先會對Trainer類成員變量進行初始化。這裏會摘取重點部分進行講解。
1)
self.num_scales = len(self.opt.scales)
self.num_input_frames = len(self.opt.frame_ids)
代碼中的opt是對options.py中的參數parse得到的dict。其參數對應值可以通過運行train.py腳本時輸入參數來進行設置,如下所示。如果在運行train.py時沒有顯示指定參數值,那麼該參數就對應使用缺省值。
python train.py --frame_ids 0 --use_stereo
回到代碼,因爲在運行train.py時沒有輸入scales參數,所以其爲缺省值[0,1,2,3],其含義是在encoder和decoder時進行4級縮小和放大的多尺度,其倍數分別對應爲1, 2, 4, 8。
frame_ids的缺省值爲[0,-1,1],這裏如果採用stereo training的話 要顯示輸入參數:--frame_ids 0,即當前圖片,而不考慮它的時間域上的上一幀和下一幀。
2)
if self.opt.use_stereo:
self.opt.frame_ids.append("s")
如果是stereo training,那麼需要顯示添加參數--use_stereo,這樣上面代碼if條件爲true, frame_ids就變成了["0", "s"]
3)接下來就到了數據加載部分
datasets_dict = {"kitti": datasets.KITTIRAWDataset,
"kitti_odom": datasets.KITTIOdomDataset}
self.dataset = datasets_dict[self.opt.dataset]
KITTI數據集有兩個子類型:KITTIRAW和KITTIOdom,monodepth使用的是前者,本系列四(https://blog.csdn.net/ltshan139/article/details/105794584)有專門對它進行說明。
fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt")
train_filenames = readlines(fpath.format("train"))
val_filenames = readlines(fpath.format("val"))
img_ext = '.png' if self.opt.png else '.jpg'
上面第一行代碼來獲取train和valid的文件路徑:fpath。在monodepth2開源項目根目錄下有一個splits的子目錄,然後在它的下面又分了eigen, eigen_full和eigen_zhou等子目錄,最後每個子目錄下才帶有train_files.txt和val_files.txt。其目錄結構如下所示:
根據github上的readme,單目訓練時推薦用的是eigen_zhou,雙目用的是eigen_full。
最後一行img_ext用來顯示告訴當前訓練和驗證樣本圖片的格式是png還是jpg。
train_dataset = self.dataset(
self.opt.data_path, train_filenames, self.opt.height, self.opt.width,
self.opt.frame_ids, 4, is_train=True, img_ext=img_ext)
self.train_loader = DataLoader(
train_dataset, self.opt.batch_size, True,
num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
val_dataset = self.dataset(
self.opt.data_path, val_filenames, self.opt.height, self.opt.width,
self.opt.frame_ids, 4, is_train=False, img_ext=img_ext)
self.val_loader = DataLoader(
val_dataset, self.opt.batch_size, True,
num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
上面的代碼就是真正數據加載部分。因爲train和valid數據加載原理一樣,而且DatalLoader是pytorch的API,沒啥好講的,所以這裏主要分析下train_dataset = self.dataset(...)的運行過程。
前面已經講過了 self.dataset=datasets.KITTIRAWDataset。調用self.dataset(...)實際上調用的是datasets.KITTIRAWDataset的構造函數,如下所示。
class KITTIRAWDataset(KITTIDataset):
"""KITTI dataset which loads the original velodyne depth maps for ground truth
"""
def __init__(self, *args, **kwargs):
super(KITTIRAWDataset, self).__init__(*args, **kwargs)
其構造函數只有一行代碼: super(KITTIRAWDataset, self).__init__(*args, **kwargs),實際上它會調用其父類KITTIDataset的構造函數,如下所示。
class KITTIDataset(MonoDataset):
"""Superclass for different types of KITTI dataset loaders
"""
def __init__(self, *args, **kwargs):
super(KITTIDataset, self).__init__(*args, **kwargs)
。。。 。。。
裏面的super函數又會調用KITTIDataset的父類MonoDataset的構造函數。
class MonoDataset(data.Dataset):
"""Superclass for monocular dataloaders
Args:
data_path
filenames
height
width
frame_idxs
num_scales
is_train
img_ext
"""
def __init__(self,
data_path,
filenames,
height,
width,
frame_idxs,
num_scales,
is_train=False,
img_ext='.jpg'):
super(MonoDataset, self).__init__()
self.data_path = data_path
self.filenames = filenames
self.height = height
self.width = width
self.num_scales = num_scales
self.interp = Image.ANTIALIAS
self.frame_idxs = frame_idxs
self.is_train = is_train
self.img_ext = img_ext
self.loader = pil_loader
self.to_tensor = transforms.ToTensor()
。。。 。。。
注意,self.dataset(。。。)所帶的實參全部賦值給了MonoDataset(。。。),比如說data_path, filenames等。相當於把全部訓練和驗證樣本文件名拿到了,以便後面訓練時一個一個batch來從數據集裏面隨機抽取。
MonoDataset的構造函數運行完成後再回到KITTIDataset的構造函數剩餘部分執行。