基於Matterport版本的Mask-RCNN訓練自己的數據集

轉載自:https://blog.csdn.net/l297969586/article/details/79140840


本文是轉載的,但可以做個補充:
將model.train()的頭結構訓練階段的epochs=1改爲較大的數,如100,將全網絡訓練的epochs=1改爲200,可以在tensorboard上得到較爲平滑的loss曲線。如果不改,得到的loss曲線是一條直線(只保存了兩次迭代的loss數據)。因爲epochs=100後,總共會迭代100*100=10000次,所以可以將STEPS_PER_EPOCH = 100改小一點,以縮短訓練時間;


一、工具

cuda與cudnn安裝請參考我之前博客:
http://blog.csdn.net/l297969586/article/details/53320706
http://blog.csdn.net/l297969586/article/details/67632608
tensorflow安裝:
http://blog.csdn.net/l297969586/article/details/72820310
ipython-notebook:
http://blog.csdn.net/l297969586/article/details/77851039
Mask-RCNN :

https://github.com/matterport/Mask_RCNN

labelme(標註mask數據集用的):

https://github.com/wkentaro/labelme

二、修改訓練代碼

主要修改train_shapes.ipynb,我個人感覺ipython-notebook不好用,所以我將它轉成.py格式,就是把代碼粘出來。let’s go!
1、註釋%matplotlib inline
2、在ShapesConfig類中,GPU_COUNT = 2,IMAGES_PER_GPU = 1兩個參數自己根據自己電腦配置修改參數,由於該工程用的resnet101爲主幹的網絡,訓練需要大量的顯存支持,我的圖片尺寸爲1280*800的,IMAGES_PER_GPU 設置爲2,在兩個GeForce GTX TITAN X上訓練顯存都會溢出,所以IMAGES_PER_GPU = 1,大佬可忽略。
NUM_CLASSES = 1 + 4爲你數據集的類別數,第一類爲bg,我的是4類,所以爲1+4
IMAGE_MIN_DIM = 800,IMAGE_MAX_DIM = 1280修改爲自己圖片尺寸
RPN_ANCHOR_SCALES = (8 * 6, 16 * 6, 32 * 6, 64 * 6, 128 * 6),根據自己情況設置anchor大小
3、在全局定義一個iter_num=0
△4、重新寫一個訓練類
名字自己起,我的叫

class DrugDataset(utils.Dataset):

添加函數

#得到該圖中有多少個實例(物體)
def get_obj_index(self, image):
        n = np.max(image)
        return n
#解析labelme中得到的yaml文件,從而得到mask每一層對應的實例標籤
def from_yaml_get_class(self,image_id):
        info=self.image_info[image_id]
        with open(info['yaml_path']) as f:
            temp=yaml.load(f.read())
            labels=temp['label_names']
            del labels[0]
        return labels
#重新寫draw_mask
def draw_mask(self, num_obj, mask, image):
        info = self.image_info[image_id]
        for index in range(num_obj):
            for i in range(info['width']):
                for j in range(info['height']):
                    at_pixel = image.getpixel((i, j))
                    if at_pixel == index + 1:
                        mask[j, i, index] =1
        return mask
#重新寫load_shapes,裏面包含自己的自己的類別(我的是box、column、package、fruit四類)
#並在self.image_info信息中添加了path、mask_path 、yaml_path
def load_shapes(self, count, height, width, img_floder, mask_floder, imglist,dataset_root_path):
        """Generate the requested number of synthetic images.
        count: number of images to generate.
        height, width: the size of the generated images.
        """
        # Add classes
        self.add_class("shapes", 1, "box")
        self.add_class("shapes", 2, "column")
        self.add_class("shapes", 3, "package")
        self.add_class("shapes", 4, "fruit")
        for i in range(count):
            filestr = imglist[i].split(".")[0]
            filestr = filestr.split("_")[1]
            mask_path = mask_floder + "/" + filestr + ".png"
            yaml_path=dataset_root_path+"total/rgb_"+filestr+"_json/info.yaml"
            self.add_image("shapes", image_id=i, path=img_floder + "/" + imglist[i],
                           width=width, height=height, mask_path=mask_path,yaml_path=yaml_path)
#重寫load_mask
    def load_mask(self, image_id):
        """Generate instance masks for shapes of the given image ID.
        """
        global iter_num
        info = self.image_info[image_id]
        count = 1  # number of object
        img = Image.open(info['mask_path'])
        num_obj = self.get_obj_index(img)
        mask = np.zeros([info['height'], info['width'], num_obj], dtype=np.uint8)
        mask = self.draw_mask(num_obj, mask, img)
        occlusion = np.logical_not(mask[:, :, -1]).astype(np.uint8)
        for i in range(count - 2, -1, -1):
            mask[:, :, i] = mask[:, :, i] * occlusion
            occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))
        labels=[]
        labels=self.from_yaml_get_class(image_id)
        labels_form=[]
        for i in range(len(labels)):
            if labels[i].find("box")!=-1:
                #print "box"
                labels_form.append("box")
            elif labels[i].find("column")!=-1:
                #print "column"
                labels_form.append("column")
            elif labels[i].find("package")!=-1:
                #print "package"
                labels_form.append("package")
            elif labels[i].find("fruit")!=-1:
                #print "fruit"
                labels_form.append("fruit")
        class_ids = np.array([self.class_names.index(s) for s in labels_form])
        return mask, class_ids.astype(np.int32)

4、代碼主體修改

#基礎設置
dataset_root_path="/home/yangjunfeng/workspace_lj/fg_dateset/"
img_floder = dataset_root_path+"rgb"
mask_floder = dataset_root_path+"mask"
#yaml_floder = dataset_root_path
imglist = listdir(img_floder)
count = len(imglist)
width = 1280
height = 800
#train與val數據集準備
dataset_train = DrugDataset()
dataset_train.load_shapes(count, 800, 1280, img_floder, mask_floder, imglist,dataset_root_path)
dataset_train.prepare()

dataset_val = DrugDataset()
dataset_val.load_shapes(count, 800, 1280, img_floder, mask_floder, imglist,dataset_root_path)
dataset_val.prepare()

註釋掉
model.train(dataset_train,dataset_val,learning_rate=config.LEARNING_RATE/10,epochs=50,layers="all")之後的代碼就好了

三、使用labelme生成mask掩碼數據集

github地址:https://github.com/wkentaro/labelme
安裝方式:

sudo apt-get install python-qt4 pyqt4-dev-tools
sudo pip install labelme

使用,只需在終端輸入:

labelme

我的數據集命名如下
這裏寫圖片描述
Note:在畫掩碼過程中如有多個box、fruit…命名規則爲box1、box2..fruit1、fruit2..。因爲labelme這個標定工具還是不太智能,最後生成的標籤爲一個label.png文件,這個文件只有一通道,在你標註時同一標籤mask會被給予一個標籤位,而mask要求不同的實例要放在不同的層中。最終訓練索要得到的輸入爲一個w*h*n的ndarray,其中n爲該圖片中實例的個數。總而言之,畫mask時就按照上述命名規則就好了,具體的過程已經在上述代碼中實現。如圖:這裏寫圖片描述
此時labelme生成的爲.json文件,需要將json文件轉換爲我們需要的標籤文件,我這裏寫了一個簡單的腳本,不用一個個去轉化了,只需將s1改爲你對應的路徑及圖片前綴名,循環數改爲自己數據集數即可

#!/bin/bash
s1="/media/lj/GSP1RMCPRXV/fg_dateset/json/rgb_"
s2=".json"
for((i=1;i<901;i++))
do 
s3=${i}
labelme_json_to_dataset ${s1}</span><span class="hljs-variable">${s3}${s2}
done

在你圖片目錄下會生成多個rgb_x_json文件夾,每個文件夾中有img.png(原圖),info.yaml,label.png,label_viz.png四個文件,其中需要用的只有info.yaml以及label.png
轉化出來的可視化標籤如圖:
這裏寫圖片描述

四、轉化label.png爲可用格式

labelme生成的掩碼標籤 label.png爲16位存儲,opencv默認讀取8位,需要將16位轉8位
參考:http://blog.csdn.net/l297969586/article/details/79154150

五、訓練

直接運行修改後的py文件即可,訓練中圖片展示:
這裏寫圖片描述

六、結果展示

測試demo也需要改,回頭再寫。。
我只訓練了四個類(box,column,package,friut)
測試圖片未參與訓練,測試結果如下:
這裏寫圖片描述
這裏寫圖片描述

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章