Caffe中使用Python腳本在cifar10數據集上測試模型

     除了使用caffe的shell命令外,還可以調用caffe的python接口來測試數據集。本文測試的模型是squeezenet,測試的數據集是cifar10,cifar10_test_lmdb中有10000張圖片。大小是3X32X32。

import numpy as np
import lmdb
import sys
import time
import pdb

sys.path.insert(0, './python') # just to import caffe python library

import caffe
import caffe.proto
import caffe.io

from caffe.proto import caffe_pb2
from caffe.io import blobproto_to_array

MODEL_FILE = 'examples/SqueezeNet/SqueezeNet_v1.1/deploy.prototxt'#注意這是deploy.prototxt文件
PRETRAINED = 'examples/SqueezeNet/SqueezeNet_v1.8_iter_300000.caffemodel'

MEAN_FILE = 'examples/cifar10/mean.binaryproto'#使用的是.binaryproto文件不是.npy
META_FILE = 'examples/cifar10/batches.meta.txt'

def load_labels(meta_file=META_FILE):
    with open(meta_file) as mfile:
        labelNames = [x.strip('\n') for x in mfile.readlines()]
    labelNames.remove('')
    return labelNames

def load_mean(mean_file=MEAN_FILE):
    blob = caffe_pb2.BlobProto()
    data = open(mean_file, "rb").read()
    blob.ParseFromString(data)
    nparray = blobproto_to_array(blob)
   # print (nparray)
    return nparray[0]

caffe.set_mode_gpu()

net = caffe.Net(MODEL_FILE,
                PRETRAINED,
                caffe.TEST)

transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})#預處理
#transformer.set_transpose('data', (2,0,1))
#transformer.set_mean('data', load_mean()) # mean pixel
#transformer.set_raw_scale('data', 255)  # the reference model operates on images in [0,255] range instead of [0,1]
#transformer.set_channel_swap('data', (2,1,0))  # the reference model has channels in BGR order instead of RGB
#print (transformer)

# read image from database
lmdb_env = lmdb.open('examples/cifar10/cifar10_test_lmdb')#讀取的是lmdb格式的。
#lmdb_env = lmdb.open('examples/cifar10/cifar10_train_lmdb')

lmdb_txn = lmdb_env.begin()
lmdb_cursor = lmdb_txn.cursor()

meanImage = load_mean()

count = 0
labels_set = set()
lbl_names = load_labels()
numLabels = len(lbl_names)

confusionMatrix = np.zeros((numLabels,numLabels))
for key, value in lmdb_cursor:
    datum = caffe.proto.caffe_pb2.Datum()
    datum.ParseFromString(value)
    label = int(datum.label)
   # print (label)
    image = caffe.io.datum_to_array(datum)
    image = image.astype(np.uint8)
   # print (image)
    normImg = image-meanImage
   # normImg = image
    #pdb.set_trace()
   # print (normImg.shape)格式是3x32x32

    #out = net.forward_all(data=np.asarray([image]))
    out = net.forward_all(data=np.asarray([normImg]))
   # out = net.forward_all(data=np.asarray([transformer.preprocess('data', normImg)]))
   # out = net.forward()
   # print (out)
    prediction = int(out['prob'][0].argmax(axis=0))
    confusionMatrix[label,prediction] += 1 
    if confusionMatrix.sum() % 1000 == 0:
       print ("processed %i" % confusionMatrix.sum())

#normalize to percent.  The next few lines could be replaced by a simple divide by 10 for CIFAR
numEachClass = confusionMatrix.sum(1)  #should b 1000 in CIFAR
sumCorrect = 0
for i in range(confusionMatrix.shape[0]):
  confusionMatrix[i,:] *= 100 / numEachClass[i]
  sumCorrect += confusionMatrix[i,i]

print ("Total Accuracy: %.1f%% \n" % (sumCorrect/confusionMatrix.shape[0]))
print (lbl_names)
for i in range(confusionMatrix.shape[0]):
    print (confusionMatrix[i,:], lbl_names[i])


只需要修改相應的文件路徑即可使用。輸出結果如下,這個矩陣橫向是label,縱向是prediction。

總結幾個坑:

1.在caffe根目錄下運行,由於文件中的路徑設置。這是Python3的代碼。

2.代碼中使用的是deploy.prototxt文件,這個文件是caffe中專門用於測試的文件,和train_val.prototxt文件不一樣,這個只用於一次測試。deploy.prototxt文件中的輸入數據的格式改成了dim:1,dim:3,dim:32,dim:32。

3.最重要也是最坑爹的是,caffe是根據每一層的name來匹配weights,所以你在deploy.prototxt中的每一層的name要和你用train_val.prototxt的模型的每一層名字一樣。特別是在微調時,我們需要把公版的模型改成自己的,就要把最後一層的名字和輸出數修改,好讓它只訓練最後一層的參數。因此,千萬不要忘了在deploy.prototxt中也要修改。不然,會出現輸出的每一類的得分全是0.1,分類出現問題,導致精度只有0.1。

4.有的代碼中說要將均值文件.binaryproto轉爲.npy文件,其實是因爲python中只能讀取nparray格式數據,本代碼中有定義函數來轉換。

5.可以看到其他的代碼中,會有許多預處理,是因爲使用caffe.io.load_image()讀進來的是RGB格式和0~1(float),所以在進行識別之前,要在transformer中設置transformer.set_raw_scale('data',255)(縮放至0~255)以transformer.set_channel_swap('data',(2,1,0)(將RGB變換到BGR)還有將數據格式從32x32x3變成3x32x32。但是本代碼中將這些預處理的操作給註釋掉是因爲讀取的文件時lmdb格式數據。除了均值以外其他操作都可省略。當然cifar10的均值也可以直接用:

參考文獻:

pycaffe做識別時通道轉換問題:https://blog.csdn.net/jacke121/article/details/79247063

https://github.com/bobf34/caffe_examples

www.cnblogs.com/Allen-rg/p/5834551.html

https://www.jb51.net/article/168471.htm

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章