Personal blog: http://www.chenjianqu.com/
Original link: http://www.chenjianqu.com/show-91.html
My previous post, <SSD Paper Notes>, covered the SSD paper. The original authors implemented SSD in Caffe, a framework I am not familiar with, so I turned to this PyTorch implementation: https://github.com/amdegroot/ssd.pytorch (the GitHub page gives installation and usage instructions). This post reads through the project's source code to deepen my understanding of both SSD and PyTorch.
After downloading the dataset and the VGG weight file, you can train an SSD model with train.py, so that is where I start reading.
train.py first reads the command-line arguments through an argument parser, then sets the default tensor type according to whether the GPU is enabled and creates the corresponding folder:
train.py
def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")

# Initialize the argument parser
parser = argparse.ArgumentParser(
    description='Single Shot MultiBox Detector Training With Pytorch')
# Create a mutually exclusive group; its arguments cannot appear together
train_set = parser.add_mutually_exclusive_group()
# Dataset selection
parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                    type=str, help='VOC or COCO')
# Dataset path; VOC_ROOT is the VOC dataset directory, defined in data/voc0712.py
parser.add_argument('--dataset_root', default=VOC_ROOT,
                    help='Dataset root directory path')
# Backbone network
parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
                    help='Pretrained base model')
# Batch size
parser.add_argument('--batch_size', default=32, type=int,
                    help='Batch size for training')
# Checkpoint to resume training from
parser.add_argument('--resume', default=None, type=str,
                    help='Checkpoint state_dict file to resume training from')
# Iteration to resume at
parser.add_argument('--start_iter', default=0, type=int,
                    help='Resume training at this iter')
# Number of workers for data loading
parser.add_argument('--num_workers', default=4, type=int,
                    help='Number of workers used in dataloading')
# Train on GPU
parser.add_argument('--cuda', default=True, type=str2bool,
                    help='Use CUDA to train model')
# Learning rate
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                    help='initial learning rate')
# Momentum
parser.add_argument('--momentum', default=0.9, type=float,
                    help='Momentum value for optim')
# Weight decay
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='Weight decay for SGD')
# Learning-rate decay factor
parser.add_argument('--gamma', default=0.1, type=float,
                    help='Gamma update for SGD')
# Whether to use visdom
parser.add_argument('--visdom', default=False, type=str2bool,
                    help='Use visdom for loss visualization')
# Directory for saving weights
parser.add_argument('--save_folder', default='weights/',
                    help='Directory for saving checkpoint models')
# Parse the arguments into a namespace object
args = parser.parse_args()

# Set the default tensor type according to whether the GPU is enabled
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but aren't " +
              "using CUDA.\nRun with --cuda for optimal training speed.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')

# Weights folder
if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)
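Note that str2bool only whitelists four strings; any other value (including "false" or "0") simply parses as False, which is what makes it safe to use as an argparse type. A standalone check:

```python
def str2bool(v):
    # Same helper as in train.py: only these four strings count as true.
    return v.lower() in ("yes", "true", "t", "1")

print(str2bool("True"))   # True  (case-insensitive match on "true")
print(str2bool("0"))      # False ("0" is not in the whitelist)
print(str2bool("no"))     # False (anything unrecognized is False)
```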
Then the train() function runs. Inside it, the dataset and dataset_root arguments determine which dataset to load and initialize:
train.py train()
if args.dataset == 'COCO':
    if args.dataset_root == VOC_ROOT:
        if not os.path.exists(COCO_ROOT):
            parser.error('Must specify dataset_root if specifying dataset')
        print("WARNING: Using default COCO dataset_root because " +
              "--dataset_root was not specified.")
        args.dataset_root = COCO_ROOT
    # Load this dataset's configuration; coco is defined in data/config.py
    cfg = coco
    # Dataset object
    dataset = COCODetection(root=args.dataset_root,
                            transform=SSDAugmentation(cfg['min_dim'], MEANS))
elif args.dataset == 'VOC':
    if args.dataset_root == COCO_ROOT:
        parser.error('Must specify dataset if specifying dataset_root')
    cfg = voc
    dataset = VOCDetection(root=args.dataset_root,
                           transform=SSDAugmentation(cfg['min_dim'], MEANS))
The coco and voc used above are dictionaries defined in data/config.py; they hold the network configuration for each dataset.
data/config.py
# SSD300 CONFIGS
# Network configuration for the VOC dataset
voc = {
    'num_classes': 21,                          # number of classes
    'lr_steps': (80000, 100000, 120000),        # steps at which the learning rate drops
    'max_iter': 120000,                         # maximum number of iterations
    'feature_maps': [38, 19, 10, 5, 3, 1],      # sizes of the output feature maps
    'min_dim': 300,                             # length of the input image's short side
    'steps': [8, 16, 32, 64, 100, 300],         # size on the original image covered by one
                                                # feature-map pixel, i.e. the downsampling factor
    'min_sizes': [30, 60, 111, 162, 213, 264],  # Smin and Smax used to compute the prior boxes
    'max_sizes': [60, 111, 162, 213, 264, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],  # aspect ratios per feature map
    'variance': [0.1, 0.2],
    'clip': True,                               # whether to clip prior boxes to the image
                                                # boundary; used in prior_box
    'name': 'VOC',                              # name
}

coco = {
    'num_classes': 201,
    'lr_steps': (280000, 360000, 400000),
    'max_iter': 400000,
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'min_dim': 300,
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [21, 45, 99, 153, 207, 261],
    'max_sizes': [45, 99, 153, 207, 261, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    'variance': [0.1, 0.2],
    'clip': True,
    'name': 'COCO',
}
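Combining feature_maps with the per-layer box counts (mbox['300'] = [4, 6, 6, 6, 4, 4] in ssd.py, used by multibox() later) yields SSD300's well-known total of 8732 default boxes. A quick arithmetic check:

```python
feature_maps = [38, 19, 10, 5, 3, 1]   # from the voc/coco configs above
boxes_per_cell = [4, 6, 6, 6, 4, 4]    # mbox['300'] in ssd.py

# Each source map contributes f*f cells, each predicting n default boxes.
total = sum(f * f * n for f, n in zip(feature_maps, boxes_per_cell))
print(total)  # 8732
```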
Before tackling any machine-learning problem we need to read and preprocess the data, and PyTorch provides many tools that make this easy. torch.utils.data.Dataset is the abstract class representing a dataset; to define your own, you subclass it and override the __len__ and __getitem__ methods. Back in train.py's train(), the custom dataset objects are COCODetection and VOCDetection, defined in data/coco.py and data/voc0712.py respectively. Taking the latter as an example, the dataset object is constructed like this:
train.py train()
dataset = VOCDetection(root=args.dataset_root,
                       transform=SSDAugmentation(cfg['min_dim'], MEANS))
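The two methods a Dataset must provide can be sketched with a toy stand-in (a plain class here so the sketch needs no torch install; in real code you would subclass torch.utils.data.Dataset, which requires exactly these two overrides):

```python
class ToyDetection:
    """Minimal Dataset-style object. torch.utils.data.Dataset only requires
    __len__ and __getitem__, which is all this toy class provides."""
    def __init__(self, samples):
        self.samples = samples  # list of (image, target) pairs

    def __len__(self):
        # Size of the dataset
        return len(self.samples)

    def __getitem__(self, index):
        # Return the index-th sample
        return self.samples[index]

ds = ToyDetection([("img0", [0.1, 0.2, 0.5, 0.6, 7])])
print(len(ds))     # 1
print(ds[0][0])    # "img0"
```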
The first argument of VOCDetection's constructor is the VOC dataset directory; the second is a data-augmentation object. Let's look at that first: the SSDAugmentation class is defined in augmentations.py.
augmentations.py
class SSDAugmentation(object):
    # Arguments: input resolution, dataset RGB mean
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        # Compose several transforms into one
        self.augment = Compose([
            ConvertFromInts(),        # convert integer image data to float32
            ToAbsoluteCoords(),
            PhotometricDistort(),     # photometric augmentation
            Expand(self.mean),        # randomly expand the image, filling with self.mean
            RandomSampleCrop(),       # random sample crop
            RandomMirror(),           # randomly flip the image and boxes horizontally
            ToPercentCoords(),        # divide box centers and sizes by image width/height
            Resize(self.size),        # resize the image to self.size x self.size
            SubtractMeans(self.mean)  # subtract the mean from the image
        ])

    def __call__(self, img, boxes, labels):
        return self.augment(img, boxes, labels)
The augmentation classes it calls are all defined in augmentations.py. Compose, which chains several transforms into one, is worth a look:
augmentations.py
class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    # __call__ makes instances of the class callable
    def __call__(self, img, boxes=None, labels=None):
        # Apply the transforms in order
        for t in self.transforms:
            img, boxes, labels = t(img, boxes, labels)
        return img, boxes, labels
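To see how Compose threads the (img, boxes, labels) triple through each step, here is a toy run with two hypothetical transforms (Compose itself reproduced from above; AddOne and Double are made up for the sketch):

```python
class Compose:
    # Same chaining logic as the Compose class above.
    def __init__(self, transforms):
        self.transforms = transforms
    def __call__(self, img, boxes=None, labels=None):
        for t in self.transforms:
            img, boxes, labels = t(img, boxes, labels)
        return img, boxes, labels

class AddOne:
    # Toy transform: every transform takes and returns the full triple.
    def __call__(self, img, boxes=None, labels=None):
        return img + 1, boxes, labels

class Double:
    def __call__(self, img, boxes=None, labels=None):
        return img * 2, boxes, labels

img, boxes, labels = Compose([AddOne(), Double()])(3)
print(img)  # (3 + 1) * 2 = 8
```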
Back in train.py's train(), VOCDetection is defined in voc0712.py as follows:
data/voc0712.py
class VOCDetection(data.Dataset):
    def __init__(self, root,                                # dataset root directory
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],  # image sets
                 transform=None,                            # data augmentation
                 target_transform=VOCAnnotationTransform(), # annotation transform
                 dataset_name='VOC0712'):                   # dataset name
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        self.ids = list()  # ids stores all image paths
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))

    # Must be overridden: returns the index-th sample of the dataset
    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)
        return im, gt

    # Must be overridden: returns the size of the dataset
    def __len__(self):
        return len(self.ids)

    # Fetch the index-th image
    def pull_item(self, index):
        img_id = self.ids[index]
        # Read the XML annotation file
        target = ET.parse(self._annopath % img_id).getroot()
        # Read the image
        img = cv2.imread(self._imgpath % img_id)
        height, width, channels = img.shape
        # Parse the XML
        if self.target_transform is not None:
            target = self.target_transform(target, width, height)
        # Image augmentation
        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            img = img[:, :, (2, 1, 0)]  # convert the BGR image to RGB
            # Merge the gt boxes and labels into one array
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width


# The XML annotation parser:
class VOCAnnotationTransform(object):
    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

    # Arguments: XML content, image width, image height
    def __call__(self, target, width, height):
        """
        Arguments:
            target (annotation): the target annotation to be made usable;
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes [bbox coords, class name]
        """
        res = []
        # For each <object>
        for obj in target.iter('object'):
            # Skip objects marked difficult, unless keep_difficult is set
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
            # Get the object's <name>
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
            bndbox = []
            # Get the object's top-left and bottom-right coordinates
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                # Convert to a relative value
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
            # Get the object's label
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]
        return res  # [[xmin, ymin, xmax, ymax, label_ind], ...]
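To make the coordinate conversion concrete, here is the same parsing logic run on a tiny hand-written annotation (the image size, box values, and the hard-coded class index are made up for this sketch):

```python
import xml.etree.ElementTree as ET

# A minimal VOC-style annotation (hypothetical values).
xml_text = """
<annotation>
  <object>
    <name>dog</name>
    <difficult>0</difficult>
    <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
  </object>
</annotation>
"""
root = ET.fromstring(xml_text)
width, height = 500, 375        # image size, normally read from the image itself
class_to_ind = {"dog": 11}      # stand-in for the VOC_CLASSES lookup

res = []
for obj in root.iter("object"):
    bbox = obj.find("bndbox")
    bndbox = []
    for i, pt in enumerate(["xmin", "ymin", "xmax", "ymax"]):
        cur_pt = int(bbox.find(pt).text) - 1     # VOC coordinates are 1-indexed
        # x-coordinates (even i) are divided by width, y-coordinates by height
        cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
        bndbox.append(cur_pt)
    bndbox.append(class_to_ind[obj.find("name").text])
    res.append(bndbox)

print(res)  # one box: [0.094, ~0.637, 0.388, ~0.987, 11]
```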
Back in train.py's train() once more: after the dataset is defined, build_ssd() is called to construct the SSD network. The function is defined in ssd.py:
ssd.py build_ssd()
def build_ssd(phase, size=300, num_classes=21):
    # Only test mode or train mode is allowed
    if phase != "test" and phase != "train":
        print("ERROR: Phase: " + phase + " not recognized")
        return
    # The input resolution must be 300
    if size != 300:
        print("ERROR: You specified size " + repr(size) + ". However, " +
              "currently only SSD300 (size=300) is supported!")
        return
    # Build the VGG network; arguments: network config, input channels
    vggnet = vgg(base[str(size)], 3)
    # Build the extra layers; arguments: network config, input channels
    extras_layers = add_extras(extras[str(size)], 1024)
    # Attach the multibox heads to the VGG and extra layers;
    # arguments: vgg, extra layers, network config, number of classes
    base_, extras_, head_ = multibox(vggnet, extras_layers, mbox[str(size)], num_classes)
    # Build the SSD network
    return SSD(phase, size, base_, extras_, head_, num_classes)
SSD uses VGG as its backbone; see <SSD Paper Notes> and <VGG16> for details. VGG16 is configuration D in the figure below:
The project uses PyTorch's pretrained VGG weights, so the torchvision VGG implementation is worth a look: https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py . Here it is implemented as follows:
ssd.py vgg()
def vgg(cfg, i, batch_norm=False):  # cfg is the VGG config, i is the number of input channels
    layers = []
    in_channels = i
    # cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', 512, 512, 512]
    # 'M' denotes a max-pooling layer; 'C' denotes pooling in ceil mode, i.e. when
    # height/stride is not an integer, the output size is rounded up instead of down.
    # Numbers denote output channel counts.
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)  # dilated (atrous) convolution
    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
    layers += [pool5, conv6, nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers

The resulting VGG layers:

0  Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1  ReLU(inplace)
2  Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3  ReLU(inplace)
4  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
5  Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
6  ReLU(inplace)
7  Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
8  ReLU(inplace)
9  MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
10 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
11 ReLU(inplace)
12 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
13 ReLU(inplace)
14 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
15 ReLU(inplace)
16 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
17 Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
18 ReLU(inplace)
19 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
20 ReLU(inplace)
21 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
22 ReLU(inplace)
23 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
24 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
25 ReLU(inplace)
26 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
27 ReLU(inplace)
28 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
29 ReLU(inplace)
30 MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
31 Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
32 ReLU(inplace)
33 Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
34 ReLU(inplace)
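The first two entries of feature_maps ([38, 19]) can be verified from the listing above: only the pooling layers change the spatial size (every 3x3 conv uses padding=1), and the ceil-mode pool at index 16 rounds 75 up to 38:

```python
import math

size = 300
size = size // 2            # pool at index 4:  300 -> 150
size = size // 2            # pool at index 9:  150 -> 75
size = math.ceil(size / 2)  # pool at index 16: 75 -> 38 (ceil_mode=True rounds up)
conv4_3 = size              # the conv4_3 map used as the first detection source
size = size // 2            # pool at index 23: 38 -> 19
# pool5 at index 30 has stride 1 and padding 1, so conv7 (fc7) stays at 19.
print(conv4_3, size)        # 38 19
```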
SSD appends several extra convolutional layers to the end of the VGG network, as shown in the figure below.
They are implemented as follows:
ssd.py add_extras()
# Build the extra layers; arguments: network config, input channels
# cfg = [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], i = 1024
def add_extras(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    flag = False
    for k, v in enumerate(cfg):
        if in_channels != 'S':
            if v == 'S':
                layers += [nn.Conv2d(in_channels, cfg[k + 1],
                                     kernel_size=(1, 3)[flag], stride=2, padding=1)]
            else:
                layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
            flag = not flag
        in_channels = v
    return layers

The resulting layers:

0 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
1 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
2 Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
3 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
4 Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
5 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
6 Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
7 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
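These extra layers produce the remaining feature_maps entries [10, 5, 3, 1]. Using the standard convolution output-size formula and starting from fc7's 19x19 map (only the 3x3 convs, at odd indices, change the size):

```python
def conv_out(n, k, s=1, p=0):
    # Standard conv output-size formula: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1

sizes = [19]                                       # fc7 output
sizes.append(conv_out(sizes[-1], k=3, s=2, p=1))   # layer 1: 19 -> 10
sizes.append(conv_out(sizes[-1], k=3, s=2, p=1))   # layer 3: 10 -> 5
sizes.append(conv_out(sizes[-1], k=3))             # layer 5:  5 -> 3
sizes.append(conv_out(sizes[-1], k=3))             # layer 7:  3 -> 1
print(sizes)  # [19, 10, 5, 3, 1]
```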
SSD then predicts prior-box offsets and class confidences from multiple feature maps:
ssd.py multibox()
# cfg = [4, 6, 6, 6, 4, 4]
def multibox(vgg, extra_layers, cfg, num_classes):
    loc_layers = []
    conf_layers = []
    # From the vgg() output above, index 21 is the conv4_3 convolution
    # and -2 is VGG's conv7 (fc7) layer.
    vgg_source = [21, -2]
    # Attach a localization head (predicting prior-box offsets) and a
    # confidence head to VGG layers 21 and -2
    for k, v in enumerate(vgg_source):
        loc_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k] * 4,
                                 kernel_size=3, padding=1)]   # localization head
        conf_layers += [nn.Conv2d(vgg[v].out_channels, cfg[k] * num_classes,
                                  kernel_size=3, padding=1)]  # confidence head
    # Attach heads to some of the extra layers.
    # A[1::2] selects A[1], A[3], A[5], ...; enumerate(A, 2) starts the index at 2.
    # So the loop below picks these extra layers:
    #   1 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    #   3 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    #   5 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    #   7 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    # and wires each of them to a localization head and a confidence head.
    for k, v in enumerate(extra_layers[1::2], 2):
        loc_layers += [nn.Conv2d(v.out_channels, cfg[k] * 4,
                                 kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(v.out_channels, cfg[k] * num_classes,
                                  kernel_size=3, padding=1)]
    return vgg, extra_layers, (loc_layers, conf_layers)

Called as:

base_, extras_, head_ = multibox(vggnet, extras_layers, mbox['300'], 21)

The resulting head_ contains the localization and confidence heads:

Localization heads:
0 Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
2 Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4 Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
5 Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

Confidence heads:
0 Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
2 Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4 Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
5 Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
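The output-channel counts in that listing follow directly from mbox and num_classes; a quick check:

```python
num_classes = 21           # VOC: 20 object classes + background
mbox = [4, 6, 6, 6, 4, 4]  # boxes per cell at each source feature map

loc_channels = [n * 4 for n in mbox]             # 4 offsets per box
conf_channels = [n * num_classes for n in mbox]  # one score per class per box
print(loc_channels)   # [16, 24, 24, 24, 16, 16]
print(conf_channels)  # [84, 126, 126, 126, 84, 84]
```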
Once those three functions have produced SSD's three sub-networks, the SSD class assembles them.
ssd.py
class SSD(nn.Module):
    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase              # train or test
        self.num_classes = num_classes  # number of output classes
        # If num_classes == 21 then cfg = voc, otherwise cfg = coco
        self.cfg = (coco, voc)[num_classes == 21]
        # Prior-box object, defined in layers/functions/prior_box.py
        self.priorbox = PriorBox(self.cfg)
        # Prior boxes, used when computing the loss
        self.priors = Variable(self.priorbox.forward(), volatile=True)
        self.size = size
        # Build the VGG network
        self.vgg = nn.ModuleList(base)
        # L2 normalization layer, used to rescale vgg conv4_3
        self.L2Norm = L2Norm(512, 20)
        # Build the extra layers
        self.extras = nn.ModuleList(extras)
        # Build the localization heads
        self.loc = nn.ModuleList(head[0])
        # Build the confidence heads
        self.conf = nn.ModuleList(head[1])
        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):  # x: input image or batch of images. Shape: [batch, 3, 300, 300].
        sources = list()
        loc = list()
        conf = list()
        # Apply VGG from the first layer up to the conv4_3 ReLU
        for k in range(23):
            x = self.vgg[k](x)
        s = self.L2Norm(x)  # attach the L2 normalization layer at conv4_3, as in the paper
        sources.append(s)   # saved in sources, to be fed to the heads below
        # Apply VGG from conv4_3 to fc7
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)
        # Apply the extra layers
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)
        # Attach the localization and confidence heads to the source feature maps
        # and permute the dimensions
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())
        # After a transpose or permute you must call contiguous() before view().
        # Two explanations circulate online: one is that after the dimension
        # shuffle the tensor is no longer stored contiguously in memory, while
        # view() requires contiguous storage; the other is that the permuted
        # tensor is a shallow copy sharing storage with the original, so a
        # view() would reshape the original as well, which is not allowed.
        #
        # view() reshapes a tensor, much like numpy's reshape().
        # The heads are convolutions, so their outputs are 3-D feature maps;
        # here each prediction is flattened and all are concatenated.
        # o.size(0) is the batch size.
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
        if self.phase == "test":
            output = self.detect(
                loc.view(loc.size(0), -1, 4),  # loc preds
                self.softmax(conf.view(conf.size(0), -1, self.num_classes)),  # conf preds
                self.priors.type(type(x.data))  # default boxes
            )
        else:
            output = (
                # reshape the localization output to (batch_size, num priors, 4 offsets)
                loc.view(loc.size(0), -1, 4),
                # reshape the confidence output to (batch_size, num priors, num classes)
                conf.view(conf.size(0), -1, self.num_classes),
                self.priors
            )
        return output

    # Weight loading
    def load_weights(self, base_file):
        other, ext = os.path.splitext(base_file)
        # Note: the original code reads `if ext == '.pkl' or '.pth':`, which is
        # always true; the intended check is this one:
        if ext == '.pkl' or ext == '.pth':
            print('Loading weights into state dict...')
            self.load_state_dict(torch.load(base_file,
                                 map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')
With the code above, SSD's network structure is complete. Along the way we ran into two unfamiliar classes, PriorBox and L2Norm. What are they? For reasons of space, they will have to wait for the next post.