Faster R-CNN Visualization (Modifying demo.py to Save Intermediate Network Results)

Reposted from: http://blog.csdn.net/u010668907/article/details/51439503


This uses the Python version of Faster R-CNN: https://github.com/rbgirshick/py-faster-rcnn

The default network in demo.py, VGG16, is used as the example.

The original demo.py is at https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/demo.py

There are quite a few output images, so here are partial results for a single input image:


The image above is the original input; below it, the first image is the feature map of the layer named "conv1_1", the second is that of "rpn_cls_prob_reshape", and the third is that of "rpnoutput" (a layer name whose "/" has been stripped by the saving code).

Here is my modified code:

#!/usr/bin/env python

# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""

import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse
import math

CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')

NETS = {'vgg16': ('VGG16',
                  'VGG16_faster_rcnn_final.caffemodel'),
        'zf': ('ZF',
                  'ZF_faster_rcnn_final.caffemodel')}


def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    #plt.draw()
def save_feature_picture(data, name, image_name=None, padsize=1, padval=1):
    """Tile the channels of one blob into a square grid and save it as an image."""
    data = data[0]  # drop the batch dimension: (C, H, W)
    #print "data.shape1: ", data.shape
    n = int(np.ceil(np.sqrt(data.shape[0])))  # an n x n grid is enough to hold all C channels
    # pad the channel count up to n**2 and add a padval border on the right of each map
    padding = ((0, n ** 2 - data.shape[0]), (0, 0), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
    #print "padding: ", padding
    data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))
    #print "data.shape2: ", data.shape

    # arrange the padded maps into an n x n mosaic and collapse it to a single 2-D image
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    #print "data.shape3: ", data.shape, n
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    #print "data.shape4: ", data.shape
    plt.figure()
    plt.imshow(data, cmap='gray')
    plt.axis('off')
    #plt.show()
    if image_name is None:
        img_path = './data/feature_picture/'
    else:
        img_path = './data/feature_picture/' + image_name + "/"
        check_file(img_path)
    plt.savefig(img_path + name + ".jpg", dpi=400, bbox_inches="tight")
def check_file(path):
    # create the per-image output folder; ./data/feature_picture itself must already exist
    if not os.path.exists(path):
        os.mkdir(path)
def demo(net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im)
    # save the feature maps of every conv/pool/rpn blob computed for this image
    for k, v in net.blobs.items():
        if k.find("conv") > -1 or k.find("pool") > -1 or k.find("rpn") > -1:
            save_feature_picture(v.data, k.replace("/", ""), image_name)  # e.g. net.blobs["conv1_1"].data
    timer.toc()
    print ('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0])

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1 # because we skipped background
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH)

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Faster R-CNN demo')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--cpu', dest='cpu_mode',
                        help='Use CPU mode (overrides --gpu)',
                        action='store_true')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
                        choices=NETS.keys(), default='vgg16')

    args = parser.parse_args()

    return args

def print_param(net):
    for k, v in net.blobs.items():
        print (k, v.data.shape)
    print ""
    for k, v in net.params.items():
        print (k, v[0].data.shape)

if __name__ == '__main__':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    args = parse_args()

    prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
                            'faster_rcnn_alt_opt', 'faster_rcnn_test.pt')
    #print "prototxt: ", prototxt
    caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',
                              NETS[args.demo_net][1])

    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/script/'
                       'fetch_faster_rcnn_models.sh?').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
        cfg.GPU_ID = args.gpu_id
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)
    
    #print_param(net)

    print '\n\nLoaded network {:s}'.format(caffemodel)

    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in xrange(2):
        _, _= im_detect(net, im)

    im_names = ['000456.jpg', '000542.jpg', '001150.jpg',
                '001763.jpg', '004545.jpg']
    for im_name in im_names:
        print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
        print 'Demo for data/demo/{}'.format(im_name)
        demo(net, im_name)

    #plt.show()
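
For reference, the new save_feature_picture function can also be called by hand on a single blob after a forward pass, as the inline comment in demo() hints. A minimal sketch, assuming net and an image im have already been loaded exactly as in the main block and demo() above:

# run one forward pass so that every intermediate blob in net.blobs is populated
scores, boxes = im_detect(net, im)
# save only the first conv layer's feature maps, under data/feature_picture/000456.jpg/
save_feature_picture(net.blobs['conv1_1'].data, 'conv1_1', '000456.jpg')
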
1. Manually create a "feature_picture" folder under data/; the script can then be used as a drop-in replacement for the original demo.

2. The main addition in the code above is the function save_feature_picture, which takes the data produced at certain stages of the network's forward pass, processes it, and saves it as an image.

3. Only "certain stages" are saved because of the line if k.find("conv")>-1 or k.find("pool")>-1 or k.find("rpn")>-1 in demo() (line 110 of the listing): only blobs whose layer name contains one of these three words are kept, since the other layers cannot be saved as pictures, e.g. the fully connected layers (whose data is already 2-D).

4. Uncomment the print_param(net) call in the main block (line 174 of the listing) to print the shape of every blob and parameter in the network.

5. The end result is that data/feature_picture contains one folder per input image, named after the image, and each folder holds one picture per saved layer, named after that layer.

6. Also, some layer names contain characters that are illegal in file names; line 111 of the listing only strips the "/" character with k.replace("/", ""), so it is best to avoid other special characters in layer names. A more robust alternative is sketched after these notes.
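
As a possible refinement of notes 3 and 6 (not part of the original modification), the keyword filter and the "/" stripping could be replaced by a check on the blob's dimensionality plus a generic file-name sanitizer. A minimal sketch, assuming the same net.blobs loop inside demo():

import re

def is_picturable(blob):
    # only 4-D blobs (N, C, H, W) can be tiled into an image;
    # fully connected outputs are 2-D and are skipped automatically
    return blob.data.ndim == 4

def safe_name(layer_name):
    # replace every character that is risky in a file name, not just "/"
    return re.sub(r'[^0-9a-zA-Z_.\-]', '_', layer_name)

# inside demo(), the loop would then read:
# for k, v in net.blobs.items():
#     if is_picturable(v):
#         save_feature_picture(v.data, safe_name(k), image_name)
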

To download the images and code:

git clone https://github.com/meihuakaile/faster-rcnn.git
