pytorch 從頭開始YOLOV3(一):COCO數據集準備和讀取

YOLOV3是工業上可以用的兼顧速度和準確率的一個深度學習目標檢測模型,本系列文章將詳細解釋該模型的構成和實現,本文代碼借鑑:https://github.com/eriklindernoren/PyTorch-YOLOv3

YOLOv3: An Incremental Improvement:https://pjreddie.com/media/files/papers/YOLOv3.pdf

原理在該篇博客就寫的很詳細了,這裏就不贅述了:https://blog.csdn.net/leviopku/article/details/82660381

自己對裏面的內容進行優化修改,具體github地址等我完全完稿再放上來,儘量在這兩週完成這個系列.

1.文件組織架構

├── checkpoints/  #模型
├── data/  #數據
│   ├── get_coco_dataset.sh
│   ├── coco.names
├── utils/  #使用的函數
│   ├── __init__.py
│   ├── datasets.py
│   └── utils.py
├── config/  #配置文件
├── output/  #輸出預測
├── weights/ #模型權重
├── README.md 
├── models.py #模型
├── train.py  #訓練
├── test.py   #測試
├── detect.py #快速使用模型
└── requirements.txt  #環境

2.下載數據集

get_coco_dataset.sh 文件: 下載數據集並且製作訓練集絕對路徑文本

#!/bin/bash

# CREDIT: https://github.com/pjreddie/darknet/tree/master/scripts/get_coco_dataset.sh

# Clone COCO API
git clone https://github.com/pdollar/coco
cd coco

mkdir images
cd images

# Download Images
wget -c https://pjreddie.com/media/files/train2014.zip
wget -c https://pjreddie.com/media/files/val2014.zip

# Unzip
unzip -q train2014.zip
unzip -q val2014.zip

cd ..

# Download COCO Metadata
wget -c https://pjreddie.com/media/files/instances_train-val2014.zip
wget -c https://pjreddie.com/media/files/coco/5k.part
wget -c https://pjreddie.com/media/files/coco/trainvalno5k.part
wget -c https://pjreddie.com/media/files/coco/labels.tgz
tar xzf labels.tgz
unzip -q instances_train-val2014.zip

# Set Up Image Lists
paste <(awk "{print \"$PWD\"}" <5k.part) 5k.part | tr -d '\t' > 5k.txt
paste <(awk "{print \"$PWD\"}" <trainvalno5k.part) trainvalno5k.part | tr -d '\t' > trainvalno5k.txt

3.配置文件

config.py   可以先不看這個,這個是後面需要的路徑名和一些超參數,這不是我們關注的重點,但是需要這個.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
from pprint import pprint


class Config:
    epochs = 20
    batch_size = 1
    imge_folder = 'data/samples'
    classes = 80

    # 配置文件地址
    model_config_path = 'config/yolov3.cfg'
    data_config_path = 'config/coco.data'
    weight_path = 'weights/yolov3.weights'
    class_path = 'data/coco.names'

    # 超參數
    conf_threshold = 0.8
    nms_threshold = 0.4
    img_size = 416
    checkpoint_interval = 1
    use_cuda = True
    momentum = 0.9
    decay = 0.0005
    learning_rate = 0.001
    burn_in = 1000

    checkpoint_dir = 'checkpoints'
    train = 'data/coco/trainvalno5k.txt'
    valid = 'data/coco/5k.txt'
    names = 'data/coco.names'
    backup = 'backup/'
    eval = 'coco'
    # 判斷終端輸入是否正確

    def _parse(self, kwargs):
        state_dict = self._state_dict()
        for k, v in kwargs.items():
            if k not in state_dict:
                raise ValueError('UnKnown Option: "--%s"' % k)
            setattr(self, k, v)

        print('======user config========')
        pprint(self._state_dict())
        print('==========end============')

    # 終端輸入替換默認配置
    def _state_dict(self):
        return {k: getattr(self, k) for k, _ in Config.__dict__.items()
                if not k.startswith('_')}


opt = Config()

 

4.讀數據

在主函數中加pytorch數據加載函數

traindata = Datasets(train_path)
dataloader = torch.utils.data.DataLoader(
     traindata, batch_size=opt.batch_size, shuffle=False)

其中數據集Datasets函數爲

#!/usr/bin/ebv pyhton
# -*- coding:utf-8 -*-

from __future__ import division

import os
import numpy as np
import torch
import sys

from torch.utils.data import Dataset
from skimage.transform import resize
import cv2


class Datasets(Dataset):
    def __init__(self, list_path, img_size=416):
        with open(list_path, 'r') as file:
            # readline() 讀一行, readlines()讀全部並返回list,這裏返回的是圖像絕對地址
            self.img_files = file.readlines()
        self.label_files = [path.replace('images', 'labels').replace(
            '.png', '.txt').replace('.jpg', '.txt') for path in self.img_files]
        self.img_shape = (img_size, img_size)
        self.max_objects = 50  # 最大物體數量

    def __getitem__(self, index):
        img_path = self.img_files[index % len(self.img_files)].rstrip()
        # Python中有三個去除頭尾字符、空白符的函數,它們依次爲:
        # strip: 用來去除頭尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、製表符、空格)
        # lstrip:用來去除開頭字符、空白符(包括\n、\r、\t、' ',即:換行、回車、製表符、空格)
        # rstrip:用來去除結尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、製表符、空格)
        img = np.array(cv2.imread(img_path))

        # 把不是彩色圖像的用下一張圖像替換
        while len(img.shape) != 3:
            index += 1
            img_path = self.imge_files[index % len(self.img_files)].rstrip()
            img = np.array(cv2.imread(img_path))

        h, w, _ = img.shape

        # 填充圖片至正方形
        dim_diff = np.abs(h - w)
        pad1, pad2 = dim_diff // 2, dim_diff - dim_diff // 2
        pad = ((pad1, pad2), (0, 0), (0, 0)) if h <= w else (
            (0, 0), (pad1, pad2), (0, 0))
        # np.pad函數見 https://blog.csdn.net/qq_36332685/article/details/78803622
        # 這裏pad兩位指的是,第幾軸,頭尾增加pad1,pad2位數值
        input_img = np.pad(img, pad, 'constant', constant_values=128) / 255.

        # 這裏注意的是,圖片填充和resize(),標籤也需要做相應操作,不然對不上
        padded_h, padded_w, _ = input_img.shape
        # cv2.resize()輸出默認是3通道
        input_img = cv2.resize(input_img, (416, 416))

        input_img = np.transpose(input_img, (2, 0, 1))
        input_img = torch.from_numpy(input_img).float()

        # 製作標籤
        label_path = self.label_files[index % len(self.img_files)].rstrip()
        lables = None
        if os.path.exists(label_path):
            # 五位標籤,(類別,x,y,w,h) x,y爲矩陣中心點
            labels = np.loadtxt(label_path).reshape(-1, 5)
            x1 = w * (labels[:, 1] - labels[:, 3] / 2)
            y1 = h * (labels[:, 2] - labels[:, 4] / 2)
            x2 = w * (labels[:, 1] + labels[:, 3] / 2)
            y2 = h * (labels[:, 2] + labels[:, 4] / 2)
            # 邊界填充
            x1 += pad[1][0]
            y1 += pad[0][0]
            x2 += pad[1][0]
            y2 += pad[0][0]
            # resize
            labels[:, 1] = ((x1 + x2) / 2) / padded_w
            labels[:, 2] = ((y1 + y2) / 2) / padded_h
            labels[:, 3] *= w / padded_w
            labels[:, 4] *= h / padded_h

        # 初始化標籤結果
        filled_labels = np.zeros((self.max_objects, 5))
        # 存儲標籤,如果沒有就爲零,超過50就捨棄
        if labels is not None:
            filled_labels[range(len(labels))[:self.max_objects]
                          ] = labels[:self.max_objects]
        filled_labels = torch.from_numpy(filled_labels)
        return img_path, input_img, filled_labels

    def __len__(self):
        return len(self.img_files)

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章