Pytorch | YOWO Principles and Code Explained (Part 3)

This post continues from
Pytorch | YOWO Principles and Code Explained (Part 1) and
Pytorch | YOWO Principles and Code Explained (Part 2); you may want to read those first.

1. Analysis of test

    if opt.evaluate:
        logging('evaluating ...')
        test(0)
    else:
        for epoch in range(opt.begin_epoch, opt.end_epoch + 1):
            # Train the model for 1 epoch
            train(epoch)

            # Validate the model
            fscore = test(epoch)

            is_best = fscore > best_fscore
            if is_best:
                print("New best fscore is achieved: ", fscore)
                print("Previous fscore was: ", best_fscore)
                best_fscore = fscore

            # Save the model to backup directory
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'fscore': fscore
            }
            save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
            logging('Weights are saved to backup directory: %s' % (backupdir))

The previous part finished analyzing the whole train flow; this part focuses on the test flow: fscore = test(epoch). Step into it and the full code looks like this:

    def test(epoch):
        def truths_length(truths):
            for i in range(50):
                if truths[i][1] == 0:
                    return i

        test_loader = torch.utils.data.DataLoader(
            dataset.listDataset(basepath, testlist, dataset_use=dataset_use, shape=(init_width, init_height),
                                shuffle=False,
                                transform=transforms.Compose([
                                    transforms.ToTensor()
                                ]), train=False),
            batch_size=batch_size, shuffle=False, **kwargs)

        num_classes = region_loss.num_classes
        anchors = region_loss.anchors
        num_anchors = region_loss.num_anchors
        conf_thresh_valid = 0.005
        total = 0.0
        proposals = 0.0
        correct = 0.0
        fscore = 0.0

        correct_classification = 0.0
        total_detected = 0.0

        nbatch = file_lines(testlist) // batch_size

        logging('validation at epoch %d' % (epoch))
        model.eval()

        for batch_idx, (frame_idx, data, target) in enumerate(test_loader):
            if use_cuda:
                data = data.cuda()
            with torch.no_grad():
                output = model(data).data
                all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
                for i in range(output.size(0)):
                    boxes = all_boxes[i]
                    boxes = nms(boxes, nms_thresh)
                    if dataset_use == 'ucf101-24':
                        detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
                        if not os.path.exists('ucf_detections'):
                            os.makedirs(current_dir)
                        if not os.path.exists(current_dir):
                            os.makedirs(current_dir)
                    else:
                        detection_path = os.path.join('jhmdb_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('jhmdb_detections', 'detections_' + str(epoch))
                        if not os.path.exists('jhmdb_detections'):
                            os.mkdir(current_dir)
                        if not os.path.exists(current_dir):
                            os.mkdir(current_dir)

                    with open(detection_path, 'w+') as f_detect:
                        for box in boxes:
                            x1 = round(float(box[0] - box[2] / 2.0) * 320.0)
                            y1 = round(float(box[1] - box[3] / 2.0) * 240.0)
                            x2 = round(float(box[0] + box[2] / 2.0) * 320.0)
                            y2 = round(float(box[1] + box[3] / 2.0) * 240.0)

                            det_conf = float(box[4])
                            for j in range((len(box) - 5) // 2):
                                cls_conf = float(box[5 + 2 * j].item())

                                if type(box[6 + 2 * j]) == torch.Tensor:
                                    cls_id = int(box[6 + 2 * j].item())
                                else:
                                    cls_id = int(box[6 + 2 * j])
                                prob = det_conf * cls_conf

                                f_detect.write(
                                    str(int(box[6]) + 1) + ' ' + str(prob) + ' ' + str(x1) + ' ' + str(y1) + ' ' + str(
                                        x2) + ' ' + str(y2) + '\n')
                    truths = target[i].view(-1, 5)
                    num_gts = truths_length(truths)

                    total = total + num_gts

                    for i in range(len(boxes)):
                        if boxes[i][4] > 0.25:
                            proposals = proposals + 1

                    for i in range(num_gts):
                        box_gt = [truths[i][1], truths[i][2], truths[i][3], truths[i][4], 1.0, 1.0, truths[i][0]]
                        best_iou = 0
                        best_j = -1
                        for j in range(len(boxes)):
                            iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False)
                            if iou > best_iou:
                                best_j = j
                                best_iou = iou

                        if best_iou > iou_thresh:
                            total_detected += 1
                            if int(boxes[best_j][6]) == box_gt[6]:
                                correct_classification += 1

                        if best_iou > iou_thresh and int(boxes[best_j][6]) == box_gt[6]:
                            correct = correct + 1

                precision = 1.0 * correct / (proposals + eps)
                recall = 1.0 * correct / (total + eps)
                fscore = 2.0 * precision * recall / (precision + recall + eps)
                logging(
                    "[%d/%d] precision: %f, recall: %f, fscore: %f" % (batch_idx, nbatch, precision, recall, fscore))

        classification_accuracy = 1.0 * correct_classification / (total_detected + eps)
        locolization_recall = 1.0 * total_detected / (total + eps)

        print("Classification accuracy: %.3f" % classification_accuracy)
        print("Locolization recall: %.3f" % locolization_recall)

        return fscore

The main job of test is to return an fscore; based on this fscore, the training loop saves the best weights.
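
save_checkpoint itself is a small helper. Below is a minimal sketch of what such a helper typically does, assuming a hypothetical file-naming scheme (the exact implementation lives in the YOWO repository, so treat this as an illustration only):

    import os
    import shutil
    import torch

    def save_checkpoint(state, is_best, backupdir, dataset, clip_duration):
        # Persist the latest training state; the file names are illustrative.
        path = os.path.join(backupdir, '%s_checkpoint_%d.pth' % (dataset, clip_duration))
        torch.save(state, path)
        # Keep a separate copy whenever the f-score reached a new best.
        if is_best:
            best_path = os.path.join(backupdir, '%s_best_%d.pth' % (dataset, clip_duration))
            shutil.copyfile(path, best_path)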

2. Loading the dataset

As with train, test has to load the dataset. It uses the same listDataset class as train; only the mode differs, which changes the behavior. Step into the listDataset class (__init__ is identical, only __getitem__ differs):

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        imgpath = self.lines[index].rstrip()
        self.shape = (224, 224)
        if self.train: # For Training
            jitter = 0.2
            hue = 0.1
            saturation = 1.5 
            exposure = 1.5
            clip, label = load_data_detection(self.base_path, imgpath,  self.train, self.clip_duration, self.shape, self.dataset_use, jitter, hue, saturation, exposure)
        else: # For Testing
            frame_idx, clip, label = load_data_detection(self.base_path, imgpath, False, self.clip_duration, self.shape, self.dataset_use)
            clip = [img.resize(self.shape) for img in clip]
        if self.transform is not None:
            clip = [self.transform(img) for img in clip]
        # (self.duration, -1) + self.shape = (8, -1, 224, 224)
        clip = torch.cat(clip, 0).view((self.clip_duration, -1) + self.shape).permute(1, 0, 2, 3)
        if self.target_transform is not None:
            label = self.target_transform(label)
        self.seen = self.seen + self.num_workers
        if self.train:
            return (clip, label)
        else:
            return (frame_idx, clip, label)

Locate the line frame_idx, clip, label = load_data_detection(self.base_path, imgpath, False, self.clip_duration, self.shape, self.dataset_use) and step into it:

def load_data_detection(base_path, imgpath, train, train_dur, shape, dataset_use='ucf101-24', jitter=0.2, hue=0.1, saturation=1.5, exposure=1.5):
	......
    for i in reversed(range(train_dur)):
        # make it as a loop
        i_temp = im_ind - i * d
        while i_temp < 1:
            i_temp = max_num + i_temp
        while i_temp > max_num:
            i_temp = i_temp - max_num

        if dataset_use == 'ucf101-24':
            path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1] ,'{:05d}.jpg'.format(i_temp))
        else:
            path_tmp = os.path.join(base_path, 'rgb-images', im_split[0], im_split[1] ,'{:05d}.png'.format(i_temp))

        clip.append(Image.open(path_tmp).convert('RGB'))

    if train: # Apply augmentation
	......
    else: # No augmentation
        label = torch.zeros(50*5)
        try:
            tmp = torch.from_numpy(read_truths_args(labpath, 8.0/clip[0].width).astype('float32'))
        except Exception:
            tmp = torch.zeros(1,5)

        tmp = tmp.view(-1)
        tsz = tmp.numel()

        if tsz > 50*5:
            label = tmp[0:50*5]
        elif tsz > 0:
            label[0:tsz] = tmp
    if train:
        return clip, label
    else:
        return im_split[0] + '_' +im_split[1] + '_' + im_split[2], clip, label

As in train, clip is a sequence of consecutively sampled frames, but no data augmentation is applied at test time. The label is read from labpath: tmp = torch.from_numpy(read_truths_args(labpath, 8.0/clip[0].width).astype('float32')). The full code is as follows:

def read_truths_args(lab_path, min_box_scale):
    truths = read_truths(lab_path)
    new_truths = []
    for i in range(truths.shape[0]):
        cx = (truths[i][1] + truths[i][3]) / (2 * 320)
        cy = (truths[i][2] + truths[i][4]) / (2 * 240)
        imgw = (truths[i][3] - truths[i][1]) / 320
        imgh = (truths[i][4] - truths[i][2]) / 240
        truths[i][0] = truths[i][0] - 1
        truths[i][1] = cx
        truths[i][2] = cy
        truths[i][3] = imgw
        truths[i][4] = imgh
        if truths[i][3] < min_box_scale:
            continue
        new_truths.append([truths[i][0], truths[i][1], truths[i][2], truths[i][3], truths[i][4]])
    return np.array(new_truths)

cx and cy are the target's center coordinates, normalized to 0–1; likewise imgw and imgh are the target's width and height, also normalized to 0–1. If the target's width is too small (if truths[i][3] < min_box_scale:), it is skipped; every converted target that passes the check is appended to new_truths. Note that the frame width and height are hard-coded to 320×240 here; if you use a different dataset, this must be changed.
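
As a quick sanity check, here is the same conversion applied to one made-up annotation row ([class, x1, y1, x2, y2] in pixels on a 320×240 frame):

    import numpy as np

    # One raw annotation row; the values are invented for illustration.
    row = np.array([15.0, 64.0, 48.0, 192.0, 144.0])  # class, x1, y1, x2, y2

    cx = (row[1] + row[3]) / (2 * 320)  # (64 + 192) / 640 = 0.4
    cy = (row[2] + row[4]) / (2 * 240)  # (48 + 144) / 480 = 0.4
    w = (row[3] - row[1]) / 320         # 128 / 320 = 0.4
    h = (row[4] - row[2]) / 240         # 96 / 240 = 0.4
    print(row[0] - 1, cx, cy, w, h)     # 14.0 0.4 0.4 0.4 0.4
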
Returning to __getitem__ of listDataset: tmp is the loaded label, and tsz counts the elements in tmp. Since label = torch.zeros(50*5), label can hold at most 50 targets, so tsz is checked: if tsz exceeds what label can hold, only the first 50 targets in tmp are kept; otherwise all of tmp is copied into label. Finally, load_data_detection returns frame_idx, clip, label.
As in train, clip is concatenated into a tensor of shape 3×16×224×224: clip = torch.cat(clip, 0).view((self.clip_duration, -1) + self.shape).permute(1, 0, 2, 3).
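
A quick shape check of this chain, with random tensors standing in for real frames:

    import torch

    # 16 frames, each a 3x224x224 tensor after transforms.ToTensor()
    clip = [torch.rand(3, 224, 224) for _ in range(16)]

    x = torch.cat(clip, 0)             # (48, 224, 224)
    x = x.view((16, -1) + (224, 224))  # (16, 3, 224, 224)
    x = x.permute(1, 0, 2, 3)          # (3, 16, 224, 224): C x T x H x W
    print(x.shape)                     # torch.Size([3, 16, 224, 224])

With the data pipeline clear, return to test(epoch):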

		......
        logging('validation at epoch %d' % (epoch))
        model.eval()

        for batch_idx, (frame_idx, data, target) in enumerate(test_loader):
            if use_cuda:
                data = data.cuda()
            with torch.no_grad():
                output = model(data).data
                all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1)
                for i in range(output.size(0)):
                    boxes = all_boxes[i]
                    boxes = nms(boxes, nms_thresh)
                    if dataset_use == 'ucf101-24':
                        detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
                        if not os.path.exists('ucf_detections'):
                            os.makedirs(current_dir)
                        if not os.path.exists(current_dir):
                            os.makedirs(current_dir)
                    else:
                        detection_path = os.path.join('jhmdb_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('jhmdb_detections', 'detections_' + str(epoch))
                        if not os.path.exists('jhmdb_detections'):
                            os.mkdir(current_dir)
                        if not os.path.exists(current_dir):
                            os.mkdir(current_dir)
		......          

3. Model output

Once the data is loaded, it is pushed through the network. The model output output = model(data).data has shape 4×145×7×7 (batch of 4, with 5 anchors × (5 + 24 classes) = 145 channels on a 7×7 feature map). Next, all boxes are extracted: all_boxes = get_region_boxes(output, conf_thresh_valid, num_classes, anchors, num_anchors, 0, 1). Step in to see the full code:

def get_region_boxes(output, conf_thresh, num_classes, anchors, num_anchors, only_objectness=1, validation=False):
    anchor_step = len(anchors)//num_anchors
    if output.dim() == 3:
        output = output.unsqueeze(0)
    batch = output.size(0)
    assert(output.size(1) == (5+num_classes)*num_anchors)
    h = output.size(2)
    w = output.size(3)
    t0 = time.time()
    all_boxes = []
    output = output.view(batch*num_anchors, 5+num_classes, h*w).transpose(0,1).contiguous().view(5+num_classes, batch*num_anchors*h*w)
    grid_x = torch.linspace(0, w-1, w).repeat(h,1).repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
    grid_y = torch.linspace(0, h-1, h).repeat(w,1).t().repeat(batch*num_anchors, 1, 1).view(batch*num_anchors*h*w).cuda()
    xs = torch.sigmoid(output[0]) + grid_x
    ys = torch.sigmoid(output[1]) + grid_y
    anchor_w = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([0]))
    anchor_h = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([1]))
    anchor_w = anchor_w.repeat(batch, 1).repeat(1, 1, h*w).view(batch*num_anchors*h*w).cuda()
    anchor_h = anchor_h.repeat(batch, 1).repeat(1, 1, h*w).view(batch*num_anchors*h*w).cuda()
    ws = torch.exp(output[2]) * anchor_w
    hs = torch.exp(output[3]) * anchor_h
    det_confs = torch.sigmoid(output[4])
    cls_confs = torch.nn.Softmax()(Variable(output[5:5+num_classes].transpose(0,1))).data
    cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)
    cls_max_confs = cls_max_confs.view(-1)
    cls_max_ids = cls_max_ids.view(-1)
    t1 = time.time()
    sz_hw = h*w
    sz_hwa = sz_hw*num_anchors
    det_confs = convert2cpu(det_confs)
    cls_max_confs = convert2cpu(cls_max_confs)
    cls_max_ids = convert2cpu_long(cls_max_ids)
    xs = convert2cpu(xs)
    ys = convert2cpu(ys)
    ws = convert2cpu(ws)
    hs = convert2cpu(hs)
    if validation:
        cls_confs = convert2cpu(cls_confs.view(-1, num_classes))
    t2 = time.time()
    for b in range(batch):
        boxes = []
        for cy in range(h):
            for cx in range(w):
                for i in range(num_anchors):
                    ind = b*sz_hwa + i*sz_hw + cy*w + cx
                    det_conf =  det_confs[ind]
                    if only_objectness:
                        conf =  det_confs[ind]
                    else:
                        conf = det_confs[ind] * cls_max_confs[ind]
    
                    if conf > conf_thresh:
                        bcx = xs[ind]
                        bcy = ys[ind]
                        bw = ws[ind]
                        bh = hs[ind]
                        cls_max_conf = cls_max_confs[ind]
                        cls_max_id = cls_max_ids[ind]
                        box = [bcx/w, bcy/h, bw/w, bh/h, det_conf, cls_max_conf, cls_max_id]
                        if (not only_objectness) and validation:
                            for c in range(num_classes):
                                tmp_conf = cls_confs[ind][c]
                                if c != cls_max_id and det_confs[ind]*tmp_conf > conf_thresh:
                                    box.append(tmp_conf)
                                    box.append(c)
                        boxes.append(box)
        all_boxes.append(boxes)
    t3 = time.time()
    if False:
        print('---------------------------------')
        print('matrix computation : %f' % (t1-t0))
        print('        gpu to cpu : %f' % (t2-t1))
        print('      boxes filter : %f' % (t3-t2))
        print('---------------------------------')
    return all_boxes

The output is first reshaped: output = output.view(batch*num_anchors, 5+num_classes, h*w).transpose(0,1).contiguous().view(5+num_classes, batch*num_anchors*h*w); the result has shape 29×980, i.e. (5+num_classes) × (batch × num_anchors × h × w). grid_x and grid_y are the (top-left) coordinate offsets of each anchor's grid cell on the feature map, the same as in training and already explained in Pytorch | YOWO Principles and Code Explained (Part 2). After adding these offsets, xs and ys are the anchors' absolute coordinates on the feature map.
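
To verify the shapes concretely, here is the same reshape and grid construction on random data (CPU-only; the .cuda() calls are dropped):

    import torch

    batch, num_anchors, num_classes, h, w = 4, 5, 24, 7, 7
    output = torch.rand(batch, (5 + num_classes) * num_anchors, h, w)  # (4, 145, 7, 7)

    flat = output.view(batch * num_anchors, 5 + num_classes, h * w) \
                 .transpose(0, 1).contiguous() \
                 .view(5 + num_classes, batch * num_anchors * h * w)
    print(flat.shape)  # torch.Size([29, 980])

    # grid_x counts columns and grid_y counts rows, each repeated for every
    # (batch, anchor) pair so they align with the 980 predictions.
    grid_x = torch.linspace(0, w - 1, w).repeat(h, 1) \
                  .repeat(batch * num_anchors, 1, 1).view(batch * num_anchors * h * w)
    grid_y = torch.linspace(0, h - 1, h).repeat(w, 1).t() \
                  .repeat(batch * num_anchors, 1, 1).view(batch * num_anchors * h * w)
    print(grid_x[:7])  # tensor([0., 1., 2., 3., 4., 5., 6.])
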
Next, the w and h of the 5 anchors are extracted:

	anchor_w = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([0]))
    anchor_h = torch.Tensor(anchors).view(num_anchors, anchor_step).index_select(1, torch.LongTensor([1]))

and tiled so that each anchor's w and h matches the 980 output predictions one-to-one:

   anchor_w = anchor_w.repeat(batch, 1).repeat(1, 1, h*w).view(batch*num_anchors*h*w).cuda()
   anchor_h = anchor_h.repeat(batch, 1).repeat(1, 1, h*w).view(batch*num_anchors*h*w).cuda()

output has shape 29×980: output[2] holds the predicted w of the 980 anchor boxes (a 980-element vector), and likewise output[3] holds the predicted h, also a 980-element vector. Then:

    ws = torch.exp(output[2]) * anchor_w
    hs = torch.exp(output[3]) * anchor_h

computes each predicted box's w and h. This matches the YOLOv2 bounding-box prediction formulas exactly:

$b_{x}=\sigma\left(t_{x}\right)+c_{x}$
$b_{y}=\sigma\left(t_{y}\right)+c_{y}$
$b_{w}=p_{w} e^{t_{w}}$
$b_{h}=p_{h} e^{t_{h}}$
det_confs is each predicted box's objectness confidence, and cls_confs applies a softmax to obtain per-class scores, with shape 980×24, i.e. 24 class scores for each of the 980 anchors. cls_max_confs, cls_max_ids = torch.max(cls_confs, 1) computes each anchor's highest class probability and its index, and both are reshaped to 980-element tensors so they line up with the anchors.
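
The same step on a dummy 29×980 tensor, using torch.softmax directly instead of the older torch.nn.Softmax()(Variable(...)) idiom that appears in the repo:

    import torch

    output = torch.rand(29, 980)                # the reshaped output from above
    det_confs = torch.sigmoid(output[4])        # objectness per anchor, shape (980,)
    cls_confs = torch.softmax(output[5:29].transpose(0, 1), dim=1)  # (980, 24)
    cls_max_confs, cls_max_ids = torch.max(cls_confs, 1)            # both (980,)
    print(det_confs.shape, cls_max_confs.shape, cls_max_ids.shape)
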
Next, this series of tensors is moved to the CPU:

    ......
    sz_hw = h*w
    sz_hwa = sz_hw*num_anchors
    det_confs = convert2cpu(det_confs)
    cls_max_confs = convert2cpu(cls_max_confs)
    cls_max_ids = convert2cpu_long(cls_max_ids)
    xs = convert2cpu(xs)
    ys = convert2cpu(ys)
    ws = convert2cpu(ws)
    hs = convert2cpu(hs)
    if validation:
        cls_confs = convert2cpu(cls_confs.view(-1, num_classes))
    ......

The loop then processes the batch item by item. For each item, the feature map is traversed over y and x, i.e. over each grid cell's offset (its top-left coordinate). ind = b*sz_hwa + i*sz_hw + cy*w + cx is the flat index of the corresponding anchor: the i-th anchor at cell (cy, cx) of the b-th item in the batch. Each batch item owns sz_hwa (245) anchors; one image is divided into sz_hw (49) grid cells, and each grid cell carries num_anchors (5) anchors.
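
A small numeric check of this index arithmetic:

    # For batch item b, anchor i, grid cell (cy, cx) on a 7x7 map with 5 anchors:
    h = w = 7
    num_anchors = 5
    sz_hw = h * w                 # 49 cells per anchor map
    sz_hwa = sz_hw * num_anchors  # 245 predictions per batch item

    b, i, cy, cx = 2, 3, 4, 5
    ind = b * sz_hwa + i * sz_hw + cy * w + cx
    print(ind)                    # 2*245 + 3*49 + 4*7 + 5 = 670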

                    if only_objectness:
                        conf =  det_confs[ind]
                    else:
                        conf = det_confs[ind] * cls_max_confs[ind]

In the test mode used during training (when only_objectness = 0), the confidence equals det_confs * cls_max_confs. If the confidence conf exceeds the threshold, the corresponding values are gathered into box: box = [bcx/w, bcy/h, bw/w, bh/h, det_conf, cls_max_conf, cls_max_id]. The snippet below then appends the probabilities and IDs of the remaining classes to box; after the loop, box has length at most 53 (7 + 23 × 2 = 53).

                        if (not only_objectness) and validation:
                            for c in range(num_classes):
                                tmp_conf = cls_confs[ind][c]
                                if c != cls_max_id and det_confs[ind]*tmp_conf > conf_thresh:
                                    box.append(tmp_conf)
                                    box.append(c)

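Putting it together, one entry of boxes for a 24-class model is laid out as follows (the values are invented for illustration):

    box = [0.43, 0.57, 0.20, 0.35,  # bcx/w, bcy/h, bw/w, bh/h (normalized)
           0.91,                    # det_conf
           0.62, 14]               # cls_max_conf, cls_max_id
    # In validation mode, every other class whose det_conf * cls_conf passes
    # the threshold appends a (conf, id) pair, so len(box) <= 7 + 23*2 = 53.
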
Finally, all qualifying boxes of each batch item are collected into all_boxes: all_boxes.append(boxes). Back in test, the code walks through each batch item's boxes, boxes = all_boxes[i], and applies non-maximum suppression to get the final output: boxes = nms(boxes, nms_thresh). The full code is:

def nms(boxes, nms_thresh):
    if len(boxes) == 0:
        return boxes

    det_confs = torch.zeros(len(boxes))
    for i in range(len(boxes)):
        det_confs[i] = 1-boxes[i][4]                

    _,sortIds = torch.sort(det_confs)
    out_boxes = []
    for i in range(len(boxes)):
        box_i = boxes[sortIds[i]]
        if box_i[4] > 0:
            out_boxes.append(box_i)
            for j in range(i+1, len(boxes)):
                box_j = boxes[sortIds[j]]
                if bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh:
                    #print(box_i, box_j, bbox_iou(box_i, box_j, x1y1x2y2=False))
                    box_j[4] = 0
    return out_boxes

det_confs stores the confidence of each detected box. torch.sort sorts in ascending order by default, but the highest-confidence boxes should come first, so the values are inverted: det_confs[i] = 1-boxes[i][4]. sortIds then holds the box indices in sorted order. The core of NMS is merging overlapping bounding boxes; overlap is judged by the IoU: when bbox_iou(box_i, box_j, x1y1x2y2=False) > nms_thresh, the redundant box_j is suppressed. Honestly, I think this NMS is questionable: 1. x1y1x2y2=True is not set, and coordinate offsets should be taken into account rather than merely measuring how much the box areas overlap; 2. the suppressed boxes are not fused with the kept ones. Of course there are many NMS variants; interested readers can explore them on their own.
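
For reference, a minimal IoU on corner-format [x1, y1, x2, y2] boxes looks like the sketch below; this is a standalone illustration, not the repository's bbox_iou:

    def iou_xyxy(a, b):
        # Intersection rectangle of two corner-format boxes.
        ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
        ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(ix2 - ix1, 0.0) * max(iy2 - iy1, 0.0)
        union = ((a[2] - a[0]) * (a[3] - a[1])
                 + (b[2] - b[0]) * (b[3] - b[1]) - inter)
        return inter / union if union > 0 else 0.0
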
Next, the detection output paths are created:

                    if dataset_use == 'ucf101-24':
                        detection_path = os.path.join('ucf_detections', 'detections_' + str(epoch), frame_idx[i])
                        current_dir = os.path.join('ucf_detections', 'detections_' + str(epoch))
                        if not os.path.exists('ucf_detections'):
                            os.makedirs(current_dir)
                        if not os.path.exists(current_dir):
                            os.makedirs(current_dir)

The detection file is opened with with open(detection_path, 'w+') as f_detect:. For each box, the score of every detected class is read, cls_conf = float(box[5 + 2 * j].item()), and the probability is computed as prob = det_conf * cls_conf, i.e. objectness confidence times class score; every result is written to the txt file:
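
Each written line therefore has the form class_id prob x1 y1 x2 y2, for example (made-up values):

    15 0.873125 112 45 208 221
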
truths holds the annotations of the corresponding image: truths = target[i].view(-1, 5), with shape 50×5. truths_length then determines the number of real targets.

                    for i in range(len(boxes)):
                        if boxes[i][4] > 0.25:
                            proposals = proposals + 1

Boxes whose confidence exceeds the 0.25 threshold are counted as proposals. Each ground truth is packed into box_gt: box_gt = [truths[i][1], truths[i][2], truths[i][3], truths[i][4], 1.0, 1.0, truths[i][0]]. The detected boxes are traversed and the IoU is computed, keeping the best IoU and the corresponding box index:

                        for j in range(len(boxes)):
                            iou = bbox_iou(box_gt, boxes[j], x1y1x2y2=False)
                            if iou > best_iou:
                                best_j = j
                                best_iou = iou

If the best IoU exceeds the threshold, the target counts as detected; if the class is also predicted correctly, it counts as a true positive:

                        if best_iou > iou_thresh:
                            total_detected += 1
                            if int(boxes[best_j][6]) == box_gt[6]:
                                correct_classification += 1

                        if best_iou > iou_thresh and int(boxes[best_j][6]) == box_gt[6]:
                            correct = correct + 1                           

What remains is computing the metrics:

		......		
                precision = 1.0 * correct / (proposals + eps)
                recall = 1.0 * correct / (total + eps)
                fscore = 2.0 * precision * recall / (precision + recall + eps)
                logging(
                    "[%d/%d] precision: %f, recall: %f, fscore: %f" % (batch_idx, nbatch, precision, recall, fscore))
        classification_accuracy = 1.0 * correct_classification / (total_detected + eps)
        locolization_recall = 1.0 * total_detected / (total + eps)
        print("Classification accuracy: %.3f" % classification_accuracy)
        print("Locolization recall: %.3f" % locolization_recall)
        return fscore
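
Written as formulas, with eps guarding against division by zero:

$\text{precision} = \frac{\text{correct}}{\text{proposals} + \epsilon}$
$\text{recall} = \frac{\text{correct}}{\text{total} + \epsilon}$
$\text{fscore} = \frac{2 \cdot \text{precision} \cdot \text{recall}}{\text{precision} + \text{recall} + \epsilon}$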

Once the computation finishes, control returns to the main program, which saves the best model according to fscore.

            if is_best:
                print("New best fscore is achieved: ", fscore)
                print("Previous fscore was: ", best_fscore)
                best_fscore = fscore
            # Save the model to backup directory
            state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'fscore': fscore
            }
            save_checkpoint(state, is_best, backupdir, opt.dataset, clip_duration)
            logging('Weights are saved to backup directory: %s' % (backupdir))

That basically completes the analysis of the test flow. For the rest, see:
Pytorch | YOWO Principles and Code Explained (Part 4) (to be updated)
