双目测距系列(八)monodepth2训练代码分析上

前言

在系列七中,我们提到了train.py中实际上只有两行训练相关的代码,第一行是Trainer构造函数的调用,主要是初始化和数据集的构建,系列七主要是对这个过程进行了梳理。第二行是Trainer成员函数train的执行,这个是训练真正执行部分,本文着重来对它进行分析。

训练数据集按batch来加载

    def train(self):
        """Run the entire training pipeline
        """
        self.epoch = 0
        self.step = 0
        self.start_time = time.time()
        for self.epoch in range(self.opt.num_epochs):
            self.run_epoch()
            if (self.epoch + 1) % self.opt.save_frequency == 0:
                self.save_model()

上面是成员函数train的实现,主要部分在run_epoch()里面。 

    def run_epoch(self):
        """Run a single epoch of training and validation
        """
        #self.model_lr_scheduler.step()

        print("Training")
        self.set_train()

        for batch_idx, inputs in enumerate(self.train_loader):

            before_op_time = time.time()

            outputs, losses = self.process_batch(inputs)
            。。。 。。。

这个函数其实也很简单,首先通过set_train()来将resnet encoder和depth decoder模型设置成训练状态,然后通过enumrate(self.train_loader)来返回一个batch大小的inputs数据。

上文说过,每当枚举train_loader时,就会调用一次mono_dataset.py中的__getitem__()。这个函数较复杂些,但很重要。方便起见,我在代码里面添加了中文解释信息。

    def __getitem__(self, index):
        
        inputs = {}

        //随机做训练数据颜色增强预处理
        do_color_aug = self.is_train and random.random() > 0.5 
        //随机做训练数据水平左右flip预处理
        do_flip = self.is_train and random.random() > 0.5

        //index是train_txt中的第index行。
        line = self.filenames[index].split()
        //train_files.txt中一行数据的第一部分,即图片所在目录。
        folder = line[0]
        //每一行一般都为3个部分,第二个部分是图片的frame_index
        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0
        
        //side为l或r,表明该图片是左或右摄像头所拍。
        if len(line) == 3:
            side = line[2]
        else:
            side = None

        //在stereo训练时, frame_idxs为["0","s"]
        //通过这个for循环,inputs[("color", "0", -1)]和inputs[("color", "s", -1)]
        //分别获得了frame_index和它对应的另外一个摄像头拍的图片数据。
        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        //因为模型有4个尺度,所以对应4个相机内参
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        //颜色增强参数设定
        if do_color_aug:
            color_aug = transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        //训练前数据预处理以及对输入数据做多尺度resize。
        self.preprocess(inputs, color_aug)
        //经过preprocess,产生了inputs[("color","0", 0/1/23)]和inputs[("color_aug","0",         
        // 0/1/23)]。所以可以将原始的inputs[("color", i, -1)]和[("color_aug", i, -1)]释放
        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        //load_depth为False,因为不需要GT label数据
        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        //在stereo训练时,还需要构造双目姿态的平移矩阵参数inputs["stereo_T"]
        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

开始处理 

 通过上面的枚举train_loader操作就可以得到各个尺度的inputs数据,然后作为参数输入到self.process_batch(inputs)。process_batch的返回值为ouputs和loss。这个函数执行完后,整个train就只剩下根据loss值backward来更新梯度,并根据优化器和lr来更新权值。

    def process_batch(self, inputs):
        """Pass a minibatch through the network and generate images and losses
        """
        for key, ipt in inputs.items():
            inputs[key] = ipt.to(self.device)

        if self.opt.pose_model_type == "shared":
            # If we are using a shared encoder for both depth and pose (as advocated
            # in monodepthv1), then all images are fed separately through the depth encoder.
            all_color_aug = torch.cat([inputs[("color_aug", i, 0)] for i in self.opt.frame_ids])
            all_features = self.models["encoder"](all_color_aug)
            all_features = [torch.split(f, self.opt.batch_size) for f in all_features]

            features = {}
            for i, k in enumerate(self.opt.frame_ids):
                features[k] = [f[i] for f in all_features]

            outputs = self.models["depth"](features[0])
        else:
            # Otherwise, we only feed the image with frame_id 0 through the depth encoder
            features = self.models["encoder"](inputs["color_aug", 0, 0])
            outputs = self.models["depth"](features)

        if self.opt.predictive_mask:
            outputs["predictive_mask"] = self.models["predictive_mask"](features)

        if self.use_pose_net:
            outputs.update(self.predict_poses(inputs, features))

        self.generate_images_pred(inputs, outputs)
        losses = self.compute_losses(inputs, outputs)

        return outputs, losses

 在上面的函数中,outputs是depth decoder求出来的,具体代码为:

features = self.models["encoder"](inputs["color_aug", 0, 0])和outputs = self.models["depth"](features)。

有了ouputs就可以来算loss,这个主要通过self.generate_images_pred(inputs, outputs)和losses = self.compute_losses(inputs, outputs)来实现。 细节将在下一篇文章来分析。


 

 


 

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章