SSD源碼閱讀一

 個人博客:http://www.chenjianqu.com/

原文鏈接:http://www.chenjianqu.com/show-91.html

上一篇博客讀了SSD的論文<SSD論文筆記>,原作者是在Caffe上實現,但是我對這個框架不太熟悉,因此找大佬們在Pytorch上的實現:https://github.com/amdegroot/ssd.pytorch ,github裏面給出了安裝和運行的步驟。這篇博客主要是通過閱讀該項目的源碼,加深對SSD和Pytorch的理解。

    在下載了數據集和vgg權重文件後,可以通過train.py訓練SSD模型,因此我從train.py文件開始閱讀。開始:::

    train.py裏面首先通過參數解析器讀取命令行傳入的參數,根據是否啓用GPU確定默認的張量類型和創建對應的文件夾:

 train.py

def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")
#初始化參數解析器
parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training With Pytorch')
#創建一個互斥組,組內參數不可同時出現
train_set = parser.add_mutually_exclusive_group()
#數據集選擇
parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                    type=str, help='VOC or COCO')
#數據集路徑,VOC_ROOT是VOC數據集的目錄,定義在data/voc0712.py裏
parser.add_argument('--dataset_root', default=VOC_ROOT,
                    help='Dataset root directory path')
#backbone網絡
parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
                    help='Pretrained base model')
#batch_size
parser.add_argument('--batch_size', default=32, type=int,
                    help='Batch size for training')
#恢復訓練的權重目錄
parser.add_argument('--resume', default=None, type=str,
                    help='Checkpoint state_dict file to resume training from')
#開始迭代的次數
parser.add_argument('--start_iter', default=0, type=int,
                    help='Resume training at this iter')
#數據處理使用的線程數
parser.add_argument('--num_workers', default=4, type=int,
                    help='Number of workers used in dataloading')
#啓用GPU訓練
parser.add_argument('--cuda', default=True, type=str2bool,
                    help='Use CUDA to train model')
#學習率
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                    help='initial learning rate')
#動量係數
parser.add_argument('--momentum', default=0.9, type=float,
                    help='Momentum value for optim')
#權重衰減
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='Weight decay for SGD')
#學習率更新系數
parser.add_argument('--gamma', default=0.1, type=float,
                    help='Gamma update for SGD')
#是否用到visdom
parser.add_argument('--visdom', default=False, type=str2bool,
                    help='Use visdom for loss visualization')
#權重保存路徑
parser.add_argument('--save_folder', default='weights/',
                    help='Directory for saving checkpoint models')
#獲得參數空間的對象
args = parser.parse_args()
#根據是否啓用GPU確定默認的張量類型
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but aren't " +
              "using CUDA.\nRun with --cuda for optimal training speed.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')
#權重文件夾
if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)

    接着會執行train()函數。在該函數裏面,通過傳入的dataset和dataset_root參數判斷讀取初始化相應的數據集:

 train.py  train()

if args.dataset == 'COCO':
    if args.dataset_root == VOC_ROOT:
        if not os.path.exists(COCO_ROOT):
            parser.error('Must specify dataset_root if specifying dataset')
        print("WARNING: Using default COCO dataset_root because " +
          "--dataset_root was not specified.")
        args.dataset_root = COCO_ROOT
    #讀取該數據集的配置,coco定義在data/config裏面
    cfg = coco 
    #數據集對象
    dataset = COCODetection(root=args.dataset_root,
                            transform=SSDAugmentation(cfg['min_dim'],
                            MEANS))
elif args.dataset == 'VOC':
    if args.dataset_root == COCO_ROOT:
        parser.error('Must specify dataset if specifying dataset_root')
    cfg = voc
    dataset = VOCDetection(root=args.dataset_root,
                           transform=SSDAugmentation(cfg['min_dim'],
                         MEANS))

    上面代碼裏面的cocovoc是定義在data/config.py裏面的字典,是對應數據集網絡配置的參數。

data/config.py

# SSD300 CONFIGS
#VOC數據的網絡配置
voc = {
    'num_classes': 21,#類別數
    'lr_steps': (80000, 100000, 120000),#學習率下降的步數
    'max_iter': 120000,#最大迭代次數
    'feature_maps': [38, 19, 10, 5, 3, 1],#輸出特徵圖的尺寸
    'min_dim': 300,#輸入圖片的短邊長
    'steps': [8, 16, 32, 64, 100, 300],#輸出特徵圖每個像素對應到原圖的大小,也就是原圖到特徵圖的下采樣係數
    'min_sizes': [30, 60, 111, 162, 213, 264],#計算先驗框用到的Smin和Smax
    'max_sizes': [60, 111, 162, 213, 264, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],#各輸出特徵圖的長寬比
    'variance': [0.1, 0.2],
    'clip': True,#是否截斷先驗框在圖像邊界內,在prior_box裏會用到
    'name': 'VOC',#名稱
}
coco = {
    'num_classes': 201,
    'lr_steps': (280000, 360000, 400000),
    'max_iter': 400000,
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'min_dim': 300,
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [21, 45, 99, 153, 207, 261],
    'max_sizes': [45, 99, 153, 207, 261, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    'variance': [0.1, 0.2],
    'clip': True,
    'name': 'COCO',
}

     在處理任何機器學習 問題之前都需要數據讀取, 並進行預處理。 PyTorch 提供了很多工具使得數據的讀取和預處理變得很容易 。  torch.utils.data.Dataset 是代表這一數據的抽象類,可以自己定義數據,只需繼承和重寫這個抽象類,需要定義__len__和__getitem__這兩個函數。回到train.py train()裏的代碼,這裏自定義的數據集對象是COCODetection和VOCDetection,分別定義在data/coco.py和voc0712.py裏面,後者爲例,數據集對象調用如下:

 train.py  train()

 

dataset = VOCDetection(root=args.dataset_root,
                       transform=SSDAugmentation(cfg['min_dim'],MEANS)
  )

    該類構造函數的第一個參數是VOC數據集的目錄,第二個參數是數據增強對象,我們來看一下:SSDAugmentation類定義在augmentations.py裏面。

augmentations.py

class SSDAugmentation(object):
#參數:輸入分辨率,數據集RGB均值
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        #將多個數據變形壓縮到一個
        self.augment = Compose([
            ConvertFromInts(),#將整形圖像數據轉換爲float32數據
            ToAbsoluteCoords(),
            PhotometricDistort(),#光度增強
            Expand(self.mean),#將圖像隨機擴展,拓展的像素值爲self.mean
            RandomSampleCrop(),#隨機採樣裁切
            RandomMirror(),#將圖像和先驗框隨機水平翻轉
            ToPercentCoords(),#將先驗框的中心和長寬除以圖像的寬和高
            Resize(self.size),#將圖像數據縮放到self.size*self.size
            SubtractMeans(self.mean) #將圖像數據減去均值
        ])
    def __call__(self, img, boxes, labels):
        return self.augment(img, boxes, labels)

    這裏面調用的數據增強類都定義在augmentations.py裏面,可以看一下Compose這個類,將多個數據增強壓縮到一個:

augmentations.py

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms
    #該函數將類實例變成可調用對象
    def __call__(self, img, boxes=None, labels=None):
        #依次執行數據變形
        for t in self.transforms:
            img, boxes, labels = t(img, boxes, labels)
        return img, boxes, labels

    回到train.py train()函數裏面,VOCDetection定義在voc0712.py裏面。如下:

data/voc0712.py

class VOCDetection(data.Dataset):
    def __init__(self, 
                 root,#數據集根目錄
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],#數據集
                 transform=None, #數據增強
                 target_transform=VOCAnnotationTransform(),#標籤數據增強
                 dataset_name='VOC0712'#數據集名稱
        ):
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        
        self.ids = list() #ids存放所有的圖片路徑
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))

                
    #需要覆蓋的函數,獲取索引數據,返回數據集中的第i個樣本
    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)
        return im, gt
    
    #需要覆蓋的函數,返回數據集的大小
    def __len__(self):
        return len(self.ids)
    
    #獲取第i個圖片
    def pull_item(self, index):
        img_id = self.ids[index]
        #讀取xml註釋文件
        target = ET.parse(self._annopath % img_id).getroot()
        #讀取圖像
        img = cv2.imread(self._imgpath % img_id)
        height, width, channels = img.shape
        #xml解析
        if self.target_transform is not None:
            target = self.target_transform(target, width, height)
        #圖像數據增強
        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            img = img[:, :, (2, 1, 0)]#將BGR圖像轉換爲RBG圖像
            #將gt box和標籤融合爲一個張量
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width

#其中xml解析器如下
class VOCAnnotationTransform(object):
    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

#參數:xml的內容,圖片的寬,圖片的高
    def __call__(self, target, width, height):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes  [bbox coords, class name]
        """
        res = []
	#對於每個<object>
        for obj in target.iter('object'):
	#判斷difficult的目標
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
	#獲取目標<name>
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
			
			bndbox = []
	#獲取目標的左上角和右下角的座標
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                #轉換爲相對值
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
			
	#獲取目標的標籤
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
			
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]

        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]

    再回到train.py train()函數,定義數據集之後,再調用build_ssd()這個函數構建SSD網絡結構,該函數定義在ssd.py裏面,如下:

ssd.py  build_ssd()

def build_ssd(phase, size=300, num_classes=21):
    #測試模式或訓練模式
    if phase != "test" and phase != "train":
        print("ERROR: Phase: " + phase + " not recognized")
        return
        
    #輸入分辨率只能是300
    if size != 300:
        print("ERROR: You specified size " + repr(size) + ". However, " +
              "currently only SSD300 (size=300) is supported!")
        return
        
    #獲取vgg網絡,參數:網絡配置、輸入通道數
    vggnet=vgg(base[str(size)], 3)
    
    #構造額外的層,參數:網絡配置、輸入通道數
    extras_layers=add_extras(extras[str(size)], 1024)
    
    #將vgg網絡和額外構造的網絡連接起來,參數:vgg,額外層、網絡配置、類別數
    base_, extras_, head_ = multibox(vggnet,extras_layers,mbox[str(size)], num_classes)
    
    #構造SSD網絡
    return SSD(phase, size, base_, extras_, head_, num_classes)

    SSD使用VGG作爲backbone,詳情:<SSD論文筆記>,<VGG16>。VGG16的網絡結構如下圖的D網絡:

1.jpg

2.jpg

    這裏使用pytorch預訓練的vgg權重,需要看一下vgg的pytorch實現:https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 。這裏的實現如下:   

ssd.py  vgg()

def vgg(cfg, i, batch_norm=False):#cfg是vgg的配置,i是輸入通道數
    layers = []
    in_channels = i
#cfg=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',512, 512, 512]
#其中M表示最大池化層,C表示池化的天花板模式,即當池化的size=stride但是height/stride不是整數時,在旁邊補上-NAN的值
#數字則表示輸出通道數
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    
    conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)#dilation表示使用空洞卷積
    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
    layers += [pool5, conv6,
               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers

這段代碼得到的vgg如下:
0 Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 ReLU(inplace)
2 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 ReLU(inplace)
4 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
5 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
6 ReLU(inplace)
7 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
8 ReLU(inplace)
9 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
10 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
11 ReLU(inplace)
12 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
13 ReLU(inplace)
14 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
15 ReLU(inplace)
16 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
17 Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
18 ReLU(inplace)
19 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
20 ReLU(inplace)
21 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
22 ReLU(inplace)
23 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
24 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
25 ReLU(inplace)
26 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
27 ReLU(inplace)
28 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
29 ReLU(inplace)
30 MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
31 Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
32 ReLU(inplace)
33 Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
34 ReLU(inplace)

   SSD在vgg網絡的末端還另外增加幾層卷積層,如下圖:

3.jpg

    這裏實現如下:

ssd.py  add_extras()

#構建額外的層,參數:網絡配置、輸入通道數
#cfg=[256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],i=1024
def add_extras(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    flag = False
    for k, v in enumerate(cfg):
        if in_channels != 'S':
            if v == 'S':
                layers += [nn.Conv2d(in_channels, cfg[k + 1],kernel_size=(1, 3)[flag], stride=2, padding=1)]
            else:
                layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
            flag = not flag
        in_channels = v
    return layers

得到的layers如下:
0 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
1 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
2 Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
3 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
4 Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
5 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
6 Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
7 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))

    然後SSD再從多個特徵圖上預測先驗框偏移和置信度,代碼如下:

ssd.py  multibox()

#cfg=[4, 6, 6, 6, 4, 4]
def multibox(vgg, extra_layers, cfg, num_classes):
    loc_layers = []
    conf_layers = []
    #從上面vgg()的結果可知,21就是指conv4_3卷積,-2指vgg的conv7(fc7)層。
    vgg_source = [21, -2]
    #在vgg的21、-2層增加定位網絡(預測先驗框偏移)和置信度網絡(預測置信度)
    for k, v in enumerate(vgg_source):
        loc_layers += [nn.Conv2d(vgg[v].out_channels,cfg[k] * 4, kernel_size=3, padding=1)] #定位網絡
        conf_layers +=[nn.Conv2d(vgg[v].out_channels,cfg[k] * num_classes, kernel_size=3, padding=1)]#置信度網絡
    #在額外層的某些層增加定位網絡和置信度網絡
    #A[1::2]表示A[1,3,5,7,...];enumerate(A,2)表示迭代的i從2開始
    #因此下面的語句表示取額外層的:{
1 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
3 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
5 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
7 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
}   
    #後面會將這些額外層連接到定位網絡和置信度網絡
    for k, v in enumerate(extra_layers[1::2], 2):
        loc_layers += [nn.Conv2d(v.out_channels, cfg[k]* 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(v.out_channels, cfg[k]* num_classes, kernel_size=3, padding=1)]
    
    return vgg, extra_layers, (loc_layers, conf_layers)
    
    
base_, extras_, head_ = multibox(vggnet,extras_layers,mbox['300'], 21)

得到的head_包含定位網絡和置信度網絡:
定位網絡
0 Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
2 Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4 Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
5 Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
置信度網絡
0 Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
2 Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4 Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
5 Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    上面這三個函數得到SSD的三部分網絡之後,在SSD這個類裏面組合起來。

ssd.py

class SSD(nn.Module):
    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase #訓練還是測試
        self.num_classes = num_classes #輸出類別數
        self.cfg = (coco, voc)[num_classes == 21] #若num_classes == 21,cfg=voc,否則cfg=coco
        #先驗框對象,定義在layers/functions/prior_box.py
        self.priorbox = PriorBox(self.cfg) 
        #設置先驗框,在計算loss時用到
        self.priors = Variable(self.priorbox.forward(), volatile=True)

        self.size = size
        #構建vgg網絡
        self.vgg = nn.ModuleList(base)
        #L2歸一化層,用於將vgg conv4_3進行縮放
        self.L2Norm = L2Norm(512, 20) 
        #構建額外層
        self.extras = nn.ModuleList(extras)
        #構建定位網絡
        self.loc = nn.ModuleList(head[0])
        #構建置信度網絡
        self.conf = nn.ModuleList(head[1])
        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):#x: input image or batch of images. Shape: [batch,3,300,300].
        sources = list()
        loc = list()
        conf = list()
        
        #應用vgg從第一層到conv4_3 relu層
        for k in range(23):
            x = self.vgg[k](x)
        
        s = self.L2Norm(x)#從con4_3處連接L2歸一化層,這點論文中也有
        
        sources.append(s)#保存到sources,下面用到用於連接到定位網絡和置信度網絡
        #應用vgg從con4_3到fc7
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)
        
        #應用額外層
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)
                
        #將定位網絡和置信度網絡連接到輸出特徵圖上,並進行維度變換
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())
‘’‘
使用transpose或permute進行維度變換後,調用contiguous,然後方可使用view對維度進行變形。
網上有兩種說法,一種是維度變換後tensor在內存中不再是連續存儲的,而view操作要求連續存儲,所以需要contiguous,另一種是說維度變換後的變量是之前變量的淺複製,指向同一區域,即view操作會連帶原來的變量一同變形,這是不合法的,
’‘’
        #view函數的作用爲重構張量的維度,相當於numpy中resize()的功能
        #因爲定位網絡和置信度網絡都是卷積,得到是三維的特徵圖,這裏將各網絡的預測結果展開,並拼接在一起
        #o.size(0)是batch_size
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
        
        if self.phase == "test":
            output = self.detect(
                loc.view(loc.size(0), -1, 4),                   # loc preds
                self.softmax(conf.view(conf.size(0), -1,
                             self.num_classes)),                # conf preds
                self.priors.type(type(x.data))                  # default boxes
            )
        else:
            output = (
                loc.view(loc.size(0), -1, 4),#最後將定位結果resize成(batch_size,先驗框的數量,總的先驗框的四個偏移)
                conf.view(conf.size(0), -1, self.num_classes),#將置信度結果resize成(batch_size,先驗框的數量,類別數)
                self.priors
            )
        return output

    #權重加載
    def load_weights(self, base_file):
        other, ext = os.path.splitext(base_file)
        if ext == '.pkl' or '.pth':
            print('Loading weights into state dict...')
            self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    經過上面這些代碼,就構建完成了SSD的網絡結構。然後發現PriorBox和L2Norm這兩個類,這是啥?篇幅所限,下篇再見。

 

 

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章