在caffe下進行手寫數字的預測

本文主要講解如何用caffe已訓練好的模型進行手寫數字的預測。準備:1、數字圖片。任意尺寸,可以是彩色,也可以是灰度的。如果是灰度的,把下面代碼裏的彩色轉灰度的代碼去掉就行。2、已訓練好的模型。例如:lenet_iter_10000.caffemodel

預測的步驟比較簡單,主要是讀取圖片、構建網絡、前向傳播。其中最關鍵的就是讀取圖片了。一開始我使用的caffe進行讀取圖片:caffe.io.load_image(),但是預測結果都是錯的,目前還沒有發現是什麼原因引起的。改用opencv進行讀取就對了。代碼是根據github上一個例程修改的,這是源代碼的鏈接:caffe-mnist-test

    下面直接給出代碼:
    #predict.py
    import os
    import caffe
    import numpy as np
    import cv2


    caffe_root = caffe_dir(替換自己的Caffe根目錄)

    MODEL_FILE = caffe_root+'examples/mnist/lenet.prototxt'
    PRETRAINED = caffe_root+'examples/mnist/lenet_iter_10000.caffemodel'
    IMAGE_FILE = caffe_root+'examples/mnist/test/8.bmp'

    img = cv2.imread(IMAGE_FILE)
    if img.shape != [28, 28]:
            img2 = cv2.resize(img, (28, 28))
            img = img2.reshape(28, 28, -1)
    else:
            img = img.reshape(28, 28, -1)
    input_image = 1.0 - img / 255.0

    print input_image.shape
    print input_image
    input_image = np.dot(input_image, np.transpose([0.3, 0.59, 0.11]))(如果原圖是灰度圖片,則不需要這步)
    input_image = input_image[:,:,np.newaxis](如果讀取的數據尺寸就是[28,28,1],則不需要此步驟)

    #img = caffe.io.load_image(IMAGE_FILE, color=False)
    net = caffe.Net(MODEL_FILE, PRETRAINED, caffe.TEST)
    caffe.set_mode_cpu()
    res = net.forward_all(data = np.asarray([input_image.transpose(2, 0, 1)]))
    prediction = res['prob'][0]
    print 'predicted class:', prediction.argmax()

    執行以下命令進行預測:
        cd caffe_root
        python predict.py

    路徑都對的話,就會輸出正確的預測分類了。

    很慶幸,找到問題根源了。原來caffe.io.load_image()讀取到的圖片數據0代表黑色,255(1)代表白色,而進行預測的時候,必須反過來,0表示白色,1表示黑色,前面代碼中的1-img/255正是這個原因。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章