blitznet測試及訓練筆記
項目地址 https://github.com/dvornikita/blitznet
環境要求
Python 3.5
Tensorflow >=1.2
Numpy 1.13
Matplotlib 2.0.0
OpenCV 3.2
PIL 4.0
glob
tabulate
progressbar
一、下載項目文件和數據集
下載項目:
git clone https://github.com/dvornikita/blitznet.git
下載數據集並解壓:
Download the data.
cd blitznet/Datasets
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
Extract the data.
tar -xvf VOCtrainval_11-May-2012.tar
tar -xvf VOCtrainval_06-Nov-2007.tar
tar -xvf VOCtest_06-Nov-2007.tar
添加 B. Hariharan et al提供的extra annotations:
python3 set_extra_annotations.py
爲了讓讀入圖片更有效率,還需運行datasets.py.需要datasets.py根據修改bernchmark中的路徑.正確路徑應爲爲:
--benchmark
--benchmark_RELEASE
--README
--dataset
--demo
--BharathICCV2011.pdf
--README
需要先修改main函數中的數據集註釋.deatset會被存在blitznet/Datasets/目錄中 由於沒有下載coco數據集的項目文檔,目前只有voc數據集,需要註釋掉coco數據集,並取消mian函數中的voc數據集的註釋
#from coco_loader import COCOLoader, COCO_CATS
if __name__ == '__main__':
create_voc_dataset('07', 'test')
create_voc_dataset('07', 'trainval')
create_voc_dataset('12', 'train', True, True)
create_voc_dataset('12', 'val', True)
至此可以嘗試demo.py和基於voc數據集的訓練和測試.
二、demo.py的運行
首先下載預先訓練好的模型,在項目文件的後面作者給出了模型文件下載地址: image
demo.py默認的模型和目錄爲 archive/BlitzNet300_COCO+VOC07+12/model.ckpt-1000.data-00000-of-00001
BlitzNet300_COCO+VOC07+12文件的下載地址爲 https://drive.google.com/open?id=0B7XqhdpFpfcIV2hqOWswU01zdlU
下載後解壓,放在archive目錄下即可.
python3 demo.py --run_name=BlitzNet300_COCO+VOC07+12 --x4 --detect --segment --eval_min_conf=0.5 --ckpt=1
作者說輸出在/Demo/output,但是Demo中沒有結果,仔細閱讀demo.py文件,發現程序中給出的目錄是demodemo/output,實際上在該目錄下仍然沒有給出demo結果. 修改
loader = Loader(osp.join(EVAL_DIR, 'demodemo'))
爲
loader = Loader(osp.join(EVAL_DIR, 'Demo'))
再次運行demo.py即可在Demo/output中得到demo.py給出的圖片分割結果.
二、訓練training.py
下載好權重文件, ResNet50和VGG16,解壓後放在Weights_imagenet文件夾下. 將ResNet50目錄下的resnet50_full.ckpt複製到Weights_imagenet目錄下,這裏使用voc07數據集進行訓練
python3 training.py --run_name=BlitzNet300_x4_VOC07 --dataset=voc07+12 --trunk=resnet50 --x4 --batch_size=8 --optimizer=adam --detect --segment --max_iterations=5000 --lr_decay 40000 50000
報錯
Invalid argument: LossTensor is inf or nan : Tensor had NaN values
將學習率設低仍然報錯.
在網上有人提出是因爲訓練集中有很小的boundingbox,需要刪除小尺寸的boundingbox.這裏直接更換訓練集爲voc12即可不報錯.命名如下
python3 training.py --run_name=BlitzNet300_x4_VOC12 --dataset=voc12-train --trunk=resnet50 --x4 --batch_size=8 --optimizer=adam --detect --segment --max_iterations=5000 --lr_decay 40000 50000
三、基於其他數據集的訓練training.py
由於blitznet輸入數據集需要同時有bboex和語義label,cityscapes數據集無法滿足要求.這裏自己製作了數據集標籤,對同一張照片的7個類別,同時標語義標籤和檢測框標籤.
bboex:用labelimg製作即可得到xml文件.
語義分割標籤:通過labelme製作得到json文件,再將json轉換爲標籤圖片.這裏步驟參考 http://note.youdao.com/noteshare?id=f171087a024667f9d4796df0d3a1696a
本文的數據集路徑爲:
fishdata
class #語義分割標籤
img #原始圖片
json #語義分割生成的json文件
xml #物體檢測框標籤
val.txt #驗證集圖片id
train.txt #訓練集圖片id
1、改寫datasets.py得到融合語義標籤和檢測框標籤的二值文件.
參照create_voc_dataset改寫出自己的create_fish_dataset.本文改寫後的代碼入下.
def create_fish_dataset(split):
"""packs a dataset to a protobuf file
Args:
split: split of data, choice=['train', 'val']
"""
loader = FISHLoader(split) #自己的數據集需要傳入的參數只有一個,用來確定是訓練集還是驗證集
#載入FISHLoader類爲loader,它包含多個函數
print("Contains %i files" % len(loader.get_filenames()))
output_file = os.path.join(loader.root, 'fish%s' %split)#修改爲自己的數據集存儲路徑
print (output_file)
#確定了融合後的文件路徑爲/media/yue/DATA/fishdata/fishval
image_placeholder = tf.placeholder(dtype=tf.uint8) #圖片的佔位符
encoded_image = tf.image.encode_png(tf.expand_dims(image_placeholder, 2)) #[2, height, width, channels]的圖片
writer = tf.python_io.TFRecordWriter(output_file)
with tf.Session('') as sess:
for i, f in enumerate(loader.get_filenames()):
#enumerate遍歷序列的元素和下標
# print ('f',f)
path = '%s/img/%s.bmp' % (loader.root, f) #修改爲自己的原始圖片存儲目錄和圖片後綴
# print ('path',path)
with tf.gfile.FastGFile(path, 'rb') as ff:
image_data = ff.read()
gt_bb, segmentation, gt_cats, w, h, diff = loader.read_annotations(f)
# print ("segmentation",segmentation)
gt_bb = normalize_bboxes(gt_bb, w, h)
png_string = sess.run(encoded_image,
feed_dict={image_placeholder: segmentation})
example = _convert_to_example(path, image_data, gt_bb, gt_cats,
diff, png_string, h, w)
if i % 100 == 0:
print("%i files are processed" % i)
writer.write(example.SerializeToString())
writer.close()
print("Done")
由於文件之間的調用,還需要參照voc_loader改寫出讀自己語言數據集的類fish_loader.本文如下:
# -*- coding: utf-8 -*-
import logging
import os
import numpy as np
import xml.etree.ElementTree as ET
from PIL import Image
from paths import DATASETS_ROOT
log = logging.getLogger()
FISH_CATS = ['__background__', 'road', 'line', 'lane', 'car','man', 'bman',
'barrier']
class FISHLoader():
def __init__(self, split):
print (DATASETS_ROOT)
self.root = os.path.join(DATASETS_ROOT, 'fishdata')
self.split = split
assert split in ['train', 'val']
cats = FISH_CATS
# print ('cats', cats)
# cats = VOC_CATS
self.cats_to_ids = dict(map(reversed, enumerate(cats)))
print ('self.cats_to_ids ',self.cats_to_ids )
self.ids_to_cats = dict(enumerate(cats))
self.num_classes = len(cats)
self.categories = cats[1:]
# print ('self.root',self.root)
with open(os.path.join(self.root, str(self.split)+'.txt'), 'r') as f:
self.filenames = f.read().split('\n')[:-1]
print ("Created a loader FISH %s with %i images" % (split, len(self.filenames)))
log.info("Created a loader FISH %s with %i images" % (split, len(self.filenames)))
def load_image(self, name):
im = Image.open('%s/img/%s.jpg' % (self.root, name)).convert('RGB')
im = np.array(im) / 255.0
im = im.astype(np.float32)
return im
#讀入的圖片爲RGB格式,且已經做了歸一化,所有的像素點的三通道值都爲0~1之間
def get_filenames(self):
return self.filenames
def read_annotations(self, name):
bboxes = []
cats = []
tree = ET.parse('%s/xml/%s.xml' % (self.root, name))
root = tree.getroot()
width = int(root.find('size/width').text)
height = int(root.find('size/height').text)
difficulty = []
for obj in root.findall('object'):
cat = self.cats_to_ids[obj.find('name').text] #cat物體的檢測類別
difficult = (int(obj.find('Difficult').text) != 0)
difficulty.append(difficult)
cats.append(cat)
bbox_tag = obj.find('bndbox')
x = int(bbox_tag.find('xmin').text)
y = int(bbox_tag.find('ymin').text)
w = int(bbox_tag.find('xmax').text)-x
h = int(bbox_tag.find('ymax').text)-y
bboxes.append((x, y, w, h))
gt_cats = np.array(cats)
gt_bboxes = np.array(bboxes).reshape((len(bboxes), 4))
#print len(gt_bboxes)
difficulty = np.array(difficulty)
seg_gt = self.read_segmentations(name, height, width)
output = gt_bboxes, seg_gt, gt_cats, width, height, difficulty
return output
#gt_cat檢測的物體的名稱,#seg_gt是對應圖片的分割掩碼,gt_bboxes圖片中所以物體的檢測框
def read_segmentations(self, name, height, width):
seg_folder = self.root + '/class/'
seg_file = seg_folder + name + '.png'
# print (seg_file)
if os.path.exists(seg_file):
# print ('seg_file',seg_file)
seg_map = Image.open(seg_file)
segmentation = np.array(seg_map, dtype=np.uint8)
else:
# if there is no segmentation for a particular image we fill the mask
# with zeros to keep the same amount of tensors but don't learn from it
segmentation = np.zeros([height, width], dtype=np.uint8) + 255
return segmentation
將path中DATASETS_ROOT改爲fishdata的存放路徑,在darasets,py的主函數裏添加
create_fish_dataset('train')
create_fish_dataset('val')
,即可得到二值文件fishtrain,fishval.
至此完成檢測框與語義分割標籤的融合.
2、改寫training.py訓練自己的魚眼數據
1、在main()函數中添加fishdata的判斷,本文添加如下
if args.dataset == 'fishdata':
dataset = get_dataset('fishtrain','fishval')
2、修改datasets.py中的splits_to_sizes中添加關於fishdata的訓練集和驗證集大小,本文如下
'fishtrain':1810,
'fishval':134,
3、training.py還需做如下修改. 將
git_diff = subprocess.check_output('git diff --no-color'.split()).decode('ascii')
中的decode(‘ascii’)去掉.如下
git_diff = subprocess.check_output('git diff --no-color'.split())
在未去掉時報錯: UnicodeDecodeError: ‘ascii’ codec can’t decode byte 0xe6 in position 1878: ordinal not in range(128). 在網上查找很多方法都未果,去掉後面的解碼方式就可以運行通過. 4、檢查數據集目錄 本文按照上述一路下來,在運行命令時,報錯
tensorflow.python.framework.errors_impl.NotFoundError: /media/yue/DATA/fishtrain; No such file or directory
直接把第1步得到的二值文件fishtrain,fishval移動到DATASETS_ROOT的目錄下,即移動到他倆的上一層目錄下.
3、運行training.py
本文的運行命令如下.
python3 training.py --run_name=FISH --dataset=fishdata --trunk=resnet50 --x4 --batch_size=8 --optimizer=adam --detect --segment --max_iterations=5000 --lr_decay 40000 50000