Extracting and Visualizing Features with Caffe

This post uses Caffe ( http://caffe.berkeleyvision.org ) to run a CNN, extract features from it, and store them in lmdb for later use; the features can also be visualized.

Classifying images with a pre-trained model

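The notebook at http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb already explains in detail how to classify test images with a pre-trained model. Because Caffe is updated continuously, the content and code of that page change as well, so what follows only records the main steps that ran at the time of writing.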

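  1. Download Caffe and install its dependencies.

  2. From caffe_root, run ./scripts/download_model_binary.py models/bvlc_reference_caffenet to fetch the pre-trained CaffeNet.

  3. In IPython (or plain Python, after commenting out the IPython-only lines, i.e. those starting with % or !), run the following code to import Caffe and make sure the model file is present.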

    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
    import numpy as np
    import matplotlib.pyplot as plt
    %matplotlib inline
    
    
    # Make sure that caffe is on the python path:
    
    caffe_root = '../'  # this file is expected to be in {caffe_root}/examples
    import sys
    sys.path.insert(0, caffe_root + 'python')
    
    import caffe
    
    plt.rcParams['figure.figsize'] = (10, 10)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'
    
    import os
    if not os.path.isfile(caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel'):
        print("Downloading pre-trained CaffeNet model...")
        !../scripts/download_model_binary.py ../models/bvlc_reference_caffenet
  4. Put the network in test mode, then load the model definition (prototxt), the trained weights, and the data mean (mean_npy).

    • ./models/bvlc_reference_caffenet/deploy.prototxt
    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
    • ./python/caffe/imagenet/ilsvrc_2012_mean.npy
    caffe.set_mode_cpu()
    net = caffe.Net(caffe_root + 'models/bvlc_reference_caffenet/deploy.prototxt',
                    caffe_root + 'models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel',
                    caffe.TEST)
    
    
    # input preprocessing: 'data' is the name of the input blob == net.inputs[0]
    
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2,0,1))
    transformer.set_mean('data', np.load(caffe_root + 'python/caffe/imagenet/ilsvrc_2012_mean.npy').mean(1).mean(1)) # mean pixel
    transformer.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]
    transformer.set_channel_swap('data', (2,1,0))  # the reference model has channels in BGR order instead of RGB
  5. Load a test image and predict its class.

    • ./examples/images/cat.jpg
    
    # set net to batch size of 50
    
    net.blobs['data'].reshape(50,3,227,227)
    
    net.blobs['data'].data[...] = transformer.preprocess('data', caffe.io.load_image(caffe_root + 'examples/images/cat.jpg'))
    out = net.forward()
    print("Predicted class is #{}.".format(out['prob'].argmax()))
  6. Load the labels and print the top_k predictions.

    • ./data/ilsvrc12/synset_words.txt
    
    # load labels
    
    imagenet_labels_filename = caffe_root + 'data/ilsvrc12/synset_words.txt'
    try:
        labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')
    except:
        !../data/ilsvrc12/get_ilsvrc_aux.sh
        labels = np.loadtxt(imagenet_labels_filename, str, delimiter='\t')
    
    
    # sort top k predictions from softmax output
    
    top_k = net.blobs['prob'].data[0].flatten().argsort()[-1:-6:-1]
    print(labels[top_k])
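Two details in the steps above are easy to gloss over: what the Transformer settings in step 4 do to an image, and how the slice in step 6 picks the top predictions. The following is a minimal NumPy sketch of both, not Caffe's own implementation. preprocess_sketch and top_k_indices are hypothetical helper names, the resize step is left out (the image is assumed to already have the network's input size), and the mean values are made-up placeholders rather than the contents of ilsvrc_2012_mean.npy.

```python
import numpy as np


def preprocess_sketch(image_rgb, mean_bgr, raw_scale=255.0):
    """Mimic the Transformer settings above on an (H, W, 3) RGB image in [0, 1]."""
    x = image_rgb.astype(np.float32)
    x = x.transpose(2, 0, 1)          # set_transpose: (H, W, C) -> (C, H, W)
    x = x[[2, 1, 0], :, :]            # set_channel_swap: RGB -> BGR
    x = x * raw_scale                 # set_raw_scale: [0, 1] -> [0, 255]
    x = x - mean_bgr[:, None, None]   # set_mean: subtract one mean value per channel
    return x


def top_k_indices(prob, k=5):
    """Indices of the k largest entries, highest first (argsort()[-1:-6:-1] when k=5)."""
    return prob.flatten().argsort()[-1:-(k + 1):-1]


# Toy inputs: a 2x2 pure-red image and placeholder mean values (NOT the real ILSVRC mean).
image = np.zeros((2, 2, 3))
image[..., 0] = 1.0
mean_bgr = np.array([10.0, 20.0, 30.0])
blob = preprocess_sketch(image, mean_bgr)

# Toy "softmax output" with distinct values so the ranking is unambiguous.
prob = np.array([0.05, 0.5, 0.1, 0.02, 0.2, 0.03, 0.08])
top = top_k_indices(prob, k=3)
```

In the toy example the red channel ends up last (BGR order), and top holds the indices of the three largest probabilities in descending order.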

Extracting and visualizing features

Continuing from the previous section: to visualize the extracted features directly, without storing them first, follow these steps.

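  1. The features (activations) are stored in net.blobs, and the weights and biases in net.params. The code below lists the name and shape of each layer; at this point you can also save them manually.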

    [(k, v.data.shape) for k, v in net.blobs.items()]
    [(k, v[0].data.shape) for k, v in net.params.items()]
  2. Visualize. The following is a helper function.

    
    # take an array of shape (n, height, width) or (n, height, width, channels)
    
    
    # and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)
    
    def vis_square(data, padsize=1, padval=0):
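        # normalize to [0, 1] on a copy, so the caller's array (e.g. net.params data) is not modified in place
        data = (data - data.min()) / (data.max() - data.min())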
    
        # force the number of filters to be square
        n = int(np.ceil(np.sqrt(data.shape[0])))
        padding = ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
        data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))
    
        # tile the filters into an image
        data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
        data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    
        plt.imshow(data)
        plt.show()
    • Using the layer names, pick the layers to visualize; both the filters (parameters) and the outputs (features) can be visualized.
    
    # the parameters are a list of [weights, biases]
    
    filters = net.params['conv1'][0].data
    vis_square(filters.transpose(0, 2, 3, 1))
    
    feat = net.blobs['conv1'].data[0, :36]
    vis_square(feat, padval=1)
    
    
    # There are 256 filters, each of which has dimension 5 x 5 x 48. We show only the first 48 filters, with each channel shown separately, so that each filter is a row.
    
    filters = net.params['conv2'][0].data
    vis_square(filters[:48].reshape(48**2, 5, 5))
    
    
    # rectified, only the first 36 of 256 channels
    
    feat = net.blobs['conv2'].data[0, :36]
    vis_square(feat, padval=1)
    
    feat = net.blobs['conv3'].data[0]
    vis_square(feat, padval=0.5)
    
    feat = net.blobs['conv4'].data[0]
    vis_square(feat, padval=0.5)
    
    feat = net.blobs['conv5'].data[0]
    vis_square(feat, padval=0.5)
    
    feat = net.blobs['pool5'].data[0]
    vis_square(feat, padval=1)
    
    feat = net.blobs['fc6'].data[0]
    plt.subplot(2, 1, 1)
    plt.plot(feat.flat)
    plt.subplot(2, 1, 2)
    _ = plt.hist(feat.flat[feat.flat > 0], bins=100)
    
    feat = net.blobs['fc7'].data[0]
    plt.subplot(2, 1, 1)
    plt.plot(feat.flat)
    plt.subplot(2, 1, 2)
    _ = plt.hist(feat.flat[feat.flat > 0], bins=100)
    
    feat = net.blobs['prob'].data[0]
    plt.plot(feat.flat)
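To see what the tiling inside vis_square does without opening a plot, the sketch below repeats the same pad, reshape, and transpose steps but returns the grid instead of drawing it. tile_square is a hypothetical name introduced here for illustration, and the input is synthetic data rather than real filters.

```python
import numpy as np


def tile_square(data, padsize=1, padval=0):
    """Same tiling as vis_square, but return the grid array instead of plotting it."""
    # smallest n such that an n x n grid holds all maps
    n = int(np.ceil(np.sqrt(data.shape[0])))
    # pad with empty maps up to n*n, and add padsize pixels below/right of each map
    padding = ((0, n ** 2 - data.shape[0]), (0, padsize), (0, padsize)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant', constant_values=padval)
    # (n*n, h, w) -> (n, n, h, w) -> (n, h, n, w) -> (n*h, n*w)
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    return data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])


# 10 synthetic 3x3 maps: they need a 4x4 grid, each cell being 3+1 pixels wide
maps = np.arange(10 * 3 * 3, dtype=np.float32).reshape(10, 3, 3)
grid = tile_square(maps)
```

With 10 maps of 3 x 3 pixels and padsize=1, the grid is 4 x 4 cells of 4 x 4 pixels each; map 0 sits in the top-left cell, map 1 in the cell to its right, and the six unused cells are filled with padval.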

Extracting and storing features

Caffe ships a tool for extracting features; see http://caffe.berkeleyvision.org/gathered/examples/feature_extraction.html

  1. Select the images to extract features from.

    • ./examples/_temp
    mkdir examples/_temp
    find `pwd`/examples/images -type f -exec echo {} \; > examples/_temp/temp.txt
    sed "s/$/ 0/" examples/_temp/temp.txt > examples/_temp/file_list.txt
  2. As before, download the model and set up the prototxt.

    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
    • ./examples/_temp/imagenet_val.prototxt
      ./data/ilsvrc12/get_ilsvrc_aux.sh
      cp examples/feature_extraction/imagenet_val.prototxt examples/_temp
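  3. Use the extract_features.bin tool to extract the features and store them as lmdb. The arguments are extract_features.bin $MODEL $PROTOTXT $LAYER $LMDB_OUTPUT_PATH $NUM_MINI_BATCHES $DB_TYPE. Note that the fifth argument is the number of mini-batches to run; the batch size itself is set in the prototxt.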

    • ./models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel
    • ./examples/_temp/imagenet_val.prototxt
    • ./examples/_temp/features
      ./build/tools/extract_features.bin models/bvlc_reference_caffenet/bvlc_reference_caffenet.caffemodel examples/_temp/imagenet_val.prototxt fc7 examples/_temp/features 10 lmdb
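The file list built in step 1 is simply one line per image, holding the absolute path, a space, and a dummy label 0. For readers who prefer to build it from Python, here is a rough equivalent of the find and sed commands. write_file_list is a hypothetical helper, and the demo runs on a throwaway temporary directory with empty placeholder files rather than on examples/images.

```python
import os
import tempfile


def write_file_list(image_dir, out_path, label=0):
    """Write '<absolute path> <label>' for every file under image_dir; return the line count."""
    lines = []
    for root, _dirs, files in os.walk(os.path.abspath(image_dir)):
        for name in sorted(files):
            lines.append("{} {}".format(os.path.join(root, name), label))
    with open(out_path, "w") as f:
        f.write("\n".join(lines) + ("\n" if lines else ""))
    return len(lines)


# Demo on a temporary directory so the sketch runs anywhere.
demo_dir = tempfile.mkdtemp()
for name in ("cat.jpg", "fish-bike.jpg"):
    open(os.path.join(demo_dir, name), "w").close()  # empty placeholder files
list_path = os.path.join(tempfile.mkdtemp(), "file_list.txt")
count = write_file_list(demo_dir, list_path)
```

For the real run, the call would be something like write_file_list('examples/images', 'examples/_temp/file_list.txt'), assuming the working directory is caffe_root.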

Visualizing from the feature file

Following http://www.cnblogs.com/platero/p/3967208.html and the lmdb documentation at https://lmdb.readthedocs.org/en/release , read the lmdb file, convert it to a mat file, and then load the mat file in MATLAB for visualization.

  1. Install Caffe's Python dependencies, then use the following two helper files to convert the lmdb to mat.

    • ./feat_helper_pb2.py

      
      # Generated by the protocol buffer compiler.  DO NOT EDIT!
      
      
      from google.protobuf import descriptor
      from google.protobuf import message
      from google.protobuf import reflection
      from google.protobuf import descriptor_pb2
      
      # @@protoc_insertion_point(imports)
      
      
      
      DESCRIPTOR = descriptor.FileDescriptor(
        name='datum.proto',
        package='feat_extract',
        serialized_pb='\n\x0b\x64\x61tum.proto\x12\x0c\x66\x65\x61t_extract\"i\n\x05\x44\x61tum\x12\x10\n\x08\x63hannels\x18\x01 \x01(\x05\x12\x0e\n\x06height\x18\x02 \x01(\x05\x12\r\n\x05width\x18\x03 \x01(\x05\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\x12\r\n\x05label\x18\x05 \x01(\x05\x12\x12\n\nfloat_data\x18\x06 \x03(\x02')
      
      
      _DATUM = descriptor.Descriptor(
        name='Datum',
        full_name='feat_extract.Datum',
        filename=None,
        file=DESCRIPTOR,
        containing_type=None,
        fields=[
          descriptor.FieldDescriptor(
            name='channels', full_name='feat_extract.Datum.channels', index=0,
            number=1, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          descriptor.FieldDescriptor(
            name='height', full_name='feat_extract.Datum.height', index=1,
            number=2, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          descriptor.FieldDescriptor(
            name='width', full_name='feat_extract.Datum.width', index=2,
            number=3, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          descriptor.FieldDescriptor(
            name='data', full_name='feat_extract.Datum.data', index=3,
            number=4, type=12, cpp_type=9, label=1,
            has_default_value=False, default_value="",
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          descriptor.FieldDescriptor(
            name='label', full_name='feat_extract.Datum.label', index=4,
            number=5, type=5, cpp_type=1, label=1,
            has_default_value=False, default_value=0,
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
          descriptor.FieldDescriptor(
            name='float_data', full_name='feat_extract.Datum.float_data', index=5,
            number=6, type=2, cpp_type=6, label=3,
            has_default_value=False, default_value=[],
            message_type=None, enum_type=None, containing_type=None,
            is_extension=False, extension_scope=None,
            options=None),
        ],
        extensions=[
        ],
        nested_types=[],
        enum_types=[
        ],
        options=None,
        is_extendable=False,
        extension_ranges=[],
        serialized_start=29,
        serialized_end=134,
      )
      
      DESCRIPTOR.message_types_by_name['Datum'] = _DATUM
      
      class Datum(message.Message):
        __metaclass__ = reflection.GeneratedProtocolMessageType
        DESCRIPTOR = _DATUM
      
        # @@protoc_insertion_point(class_scope:feat_extract.Datum)
      
      
      # @@protoc_insertion_point(module_scope)
      
    • ./lmdb2mat.py

      import sys
      import time

      import lmdb
      import numpy as np
      import scipy.io as sio

      import feat_helper_pb2


      def main(argv):
          # usage: lmdb2mat.py LMDB BATCHNUM BATCHSIZE DIM OUT
          lmdb_name = argv[1]
          print("%s" % lmdb_name)
          batch_num = int(argv[2])
          batch_size = int(argv[3])
          dim = int(argv[4])
          window_num = batch_num * batch_size

          start = time.time()

          # read every serialized Datum from the lmdb, in key order
          db = lmdb.open(lmdb_name, readonly=True)
          values = []
          with db.begin() as txn:
              for key, value in txn.cursor():
                  values.append(value)

          # parse each Datum and copy its float_data into one row of the feature matrix
          datum = feat_helper_pb2.Datum()
          ft = np.zeros((window_num, dim))
          for im_idx in range(window_num):
              datum.ParseFromString(values[im_idx])
              ft[im_idx, :] = datum.float_data

          print('time 1: %f' % (time.time() - start))
          sio.savemat(argv[5], {'feats': ft})
          print('time 2: %f' % (time.time() - start))
          print('done!')


      if __name__ == '__main__':
          main(sys.argv)
    • Run the following shell script

      
      #!/usr/bin/env sh

      LMDB=./examples/_temp/features_fc7 # path to the lmdb (must match the output path given to extract_features.bin)
      BATCHNUM=1
      BATCHSIZE=10

      # DIM=290400 # feature length for conv1
      # DIM=43264 # conv5
      DIM=4096 # fc7
      OUT=./examples/_temp/features_fc7.mat # where to save the mat file
      python ./lmdb2mat.py $LMDB $BATCHNUM $BATCHSIZE $DIM $OUT
  2. Visualize the features in the mat file with a function based on display_network from UFLDL.

    • display_network.m

      function [h, array] = display_network(A, opt_normalize, opt_graycolor, cols, opt_colmajor)
      % This function visualizes filters in matrix A. Each column of A is a
      % filter. We will reshape each column into a square image and visualizes
      % on each cell of the visualization panel. 
      % All other parameters are optional, usually you do not need to worry
      % about it.
      % opt_normalize: whether we need to normalize the filter so that all of
      % them can have similar contrast. Default value is true.
      % opt_graycolor: whether we use gray as the heat map. Default is true.
      % cols: how many columns are there in the display. Default value is the
      % squareroot of the number of columns in A.
      % opt_colmajor: you can switch convention to row major for A. In that
      % case, each row of A is a filter. Default value is false.
      warning off all
      
      if ~exist('opt_normalize', 'var') || isempty(opt_normalize)
          opt_normalize= true;
      end
      
      if ~exist('opt_graycolor', 'var') || isempty(opt_graycolor)
          opt_graycolor= true;
      end
      
      if ~exist('opt_colmajor', 'var') || isempty(opt_colmajor)
          opt_colmajor = false;
      end
      
      % rescale
      A = A - mean(A(:));
      
      if opt_graycolor, colormap(gray); end
      
      % compute rows, cols
      [L M]=size(A);
      sz=sqrt(L);
      buf=1;
      if ~exist('cols', 'var')
          if floor(sqrt(M))^2 ~= M
              n=ceil(sqrt(M));
              while mod(M, n)~=0 && n<1.2*sqrt(M), n=n+1; end
              m=ceil(M/n);
          else
              n=sqrt(M);
              m=n;
          end
      else
          n = cols;
          m = ceil(M/n);
      end
      
      array=-ones(buf+m*(sz+buf),buf+n*(sz+buf));
      
      if ~opt_graycolor
          array = 0.1.* array;
      end
      
      
      if ~opt_colmajor
          k=1;
          for i=1:m
              for j=1:n
                  if k>M, 
                      continue; 
                  end
                  clim=max(abs(A(:,k)));
                  if opt_normalize
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/clim;
                  else
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/max(abs(A(:)));
                  end
                  k=k+1;
              end
          end
      else
          k=1;
          for j=1:n
              for i=1:m
                  if k>M, 
                      continue; 
                  end
                  clim=max(abs(A(:,k)));
                  if opt_normalize
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)'/clim;
                  else
                      array(buf+(i-1)*(sz+buf)+(1:sz),buf+(j-1)*(sz+buf)+(1:sz))=reshape(A(:,k),sz,sz)';
                  end
                  k=k+1;
              end
          end
      end
      
      if opt_graycolor
          h=imagesc(array);
      else
          h=imagesc(array,'EraseMode','none',[-1 1]);
      end
      axis image off
      
      drawnow;
      
      warning on all
    • Run the following code in MATLAB:

      nsample = 2;
      % num_output = 96; % conv1
      % num_output = 256; % conv5
      num_output = 4096; % fc7
      
      load features_fc7.mat
      width = size(feats, 2);
      nmap = width / num_output;
      
      for i = 1 : nsample
          feat = feats(i, :);
          feat = reshape(feat, [nmap num_output]);
          figure('name', sprintf('image #%d', i));
          display_network(feat);
      end
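The DIM values in the shell script and num_output in the MATLAB snippet are linked by the blob shapes: DIM is the flattened per-image length (channels x height x width), and nmap = width / num_output is the number of values per channel. The arithmetic can be checked with the short sketch below. The shapes are assumed CaffeNet blob shapes, so verify them against the net.blobs listing from the earlier section before relying on them.

```python
from functools import reduce
from operator import mul

# Assumed per-image blob shapes (channels, height, width) for CaffeNet;
# check them against the output of the net.blobs listing.
blob_shapes = {
    "conv1": (96, 55, 55),
    "conv5": (256, 13, 13),
    "fc7": (4096,),
}


def feature_dim(shape):
    """Flattened feature length, i.e. the DIM argument passed to lmdb2mat.py."""
    return reduce(mul, shape, 1)


dims = {name: feature_dim(shape) for name, shape in blob_shapes.items()}

# Values per channel, i.e. nmap = DIM / num_output in the MATLAB snippet.
nmap = {name: dims[name] // shape[0] for name, shape in blob_shapes.items()}
```

Under these assumed shapes, dims reproduces the three DIM values listed in the shell script, and nmap is 1 for fc7, which is why each fc7 feature reshapes to a single row.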

Reading the mat file in Python

In Python, scipy.io.loadmat() reads a mat file and returns a dict.

import scipy.io
matfile = 'features_fc7.mat'
data = scipy.io.loadmat(matfile)
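Since the returned dict holds more than the saved variable, a small round trip may help show how to get the feature matrix back out. This is only a sketch: the file name features_demo.mat and the random matrix are stand-ins for the real features_fc7.mat, and the variable name 'feats' matches the key used by lmdb2mat.py above.

```python
import os
import tempfile

import numpy as np
import scipy.io

# Stand-in for the real feature matrix: 10 images x 4096 fc7 values (random, for illustration only).
feats = np.random.rand(10, 4096)
mat_path = os.path.join(tempfile.mkdtemp(), "features_demo.mat")
scipy.io.savemat(mat_path, {"feats": feats})

data = scipy.io.loadmat(mat_path)

# Besides the saved variables, the dict carries metadata keys such as '__header__'.
variable_names = [k for k in data if not k.startswith("__")]

# The features come back as a NumPy array with one row per image.
loaded = data["feats"]
```

Filtering out the keys that start with a double underscore leaves only the variables that were saved, here just 'feats'.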

Using your own network

Simply replace the files and parameters listed above with your own.

References

http://nbviewer.ipython.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb
http://caffe.berkeleyvision.org/gathered/examples/feature_extraction.html
http://www.cnblogs.com/platero/p/3967208.html
https://lmdb.readthedocs.org/en/release/
