caffe2學習——0207——預加載模型並測試(上)

# coding=utf-8
import numpy as np
from caffe2.proto import caffe2_pb2
import os
from caffe2.python import core, workspace,models
import matplotlib.pyplot as pyplot
import skimage
import skimage.io as io
import skimage.transform
import urllib3

print('required modules imported.')
# Raw strings keep the Windows backslashes literal; in a normal string,
# sequences such as "\A" or "\e" are invalid escapes that newer Python
# versions warn about.
CAFFE_MODELS = r'D:\Anaconda\envs\pytorch\Lib\site-packages\caffe2\python\models'
print('caffe_model path : {}'.format(CAFFE_MODELS))
IMG_LOCATION = r'F:\cocoDataAugment\data\night\night_1.jpg'
# (model directory, init net file, predict net file, mean file, input size)
# NOTE(review): the mean file name ends in ".py"; presumably ".npy" was
# intended -- confirm before actually loading it.
MODEL = 'squeezenet', 'init_net.pb', 'predict_net.pb', 'ilsvrc_2012_mean.py', 227

# URL of the text file that is scanned line by line for the predicted class id.
codes = "https://gist.githubusercontent.com/aaronmarkham/cd3a6b6ac071eca6f7b4a6e40e6038aa/raw/9edb4038a37da6b5a44c3b5bc52e448ff09bfe5b/alexnet_codes"
print('config set!')

def crop_center(img, cropx, cropy):
    """Return the centred ``cropy`` x ``cropx`` window of an image array.

    Args:
        img: Array whose first two axes are height and width; any further
            axes (for example channels) are returned whole.
        cropx: Width of the window, in pixels.
        cropy: Height of the window, in pixels.

    Returns:
        The slice ``img[starty:starty + cropy, startx:startx + cropx]``
        where ``startx``/``starty`` centre the window in the image.
    """
    # Only the two leading axes are needed, so 2-D (grayscale) input works
    # as well as H x W x C input.
    height, width = img.shape[:2]
    startx = width // 2 - cropx // 2
    starty = height // 2 - cropy // 2
    return img[starty:starty + cropy, startx:startx + cropx]

def rescale(img, input_height, input_width):
    """Resize ``img`` while keeping its aspect ratio.

    A wide image is resized to ``input_height`` rows, a tall image to
    ``input_width`` columns, and a square image to exactly
    ``(input_height, input_width)``. Progress information is printed.

    Args:
        img: Image array with height on axis 0 and width on axis 1.
        input_height: Target height expected by the model.
        input_width: Target width expected by the model.

    Returns:
        The array produced by ``skimage.transform.resize``.
    """
    print("Original image shape:" + str(img.shape) + " and remember it should be in H, W, C!")
    print("Model's input shape is:{}x{}".format(input_height, input_width))
    rows, cols = img.shape[0], img.shape[1]
    aspect = cols / float(rows)
    print("Orginal aspect ratio: " + str(aspect))

    if aspect > 1:
        # Landscape: fix the height, let the width follow the aspect ratio.
        target_shape = (input_height, int(aspect * input_height))
    elif aspect < 1:
        # Portrait: fix the width, let the height follow the aspect ratio.
        target_shape = (int(input_width / aspect), input_width)
    else:
        # Square: use the model's input size directly.
        target_shape = (input_height, input_width)

    imgScaled = skimage.transform.resize(img, target_shape)
    print("New image shape:" + str(imgScaled.shape) + " in HWC")
    return imgScaled

# Expand a leading "~" in the model directory, then locate the mean file
# inside the selected model's folder.
CAFFE_MODELS = os.path.expanduser(CAFFE_MODELS)
MEAN_FILE = os.path.join(
    CAFFE_MODELS,
    MODEL[0],
    MODEL[3],
)
# Disabled alternative that would derive the mean from MEAN_FILE:
#     if not os.path.join(MEAN_FILE):
#         mean = 128
#     else:
#         mean = np.load(MEAN_FILE).mean(1).mean(1)
#         mean = mean[:, np.newaxis, np.newaxis]
# The mean file could not be found, so a fixed value is used instead.
mean = 128
print('mean is set to : {}'.format(mean))

# Used as both the height and the width when preparing the input image.
INPUT_IMAGE_SIZE = MODEL[4]

# <CAFFE_MODELS>/<model name>/<init net file>
INIT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[1])
print('INIT_NET:{}'.format(INIT_NET))
# <CAFFE_MODELS>/<model name>/<predict net file>
PREDICT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[2])
print('predict_net : {}'.format(PREDICT_NET))

# Report which of the two model files are present; nothing is loaded here.
if os.path.exists(INIT_NET):
    print('Found ', INIT_NET, "...Now looking for ", PREDICT_NET)
    if os.path.exists(PREDICT_NET):
        print('all needed files found! loading model in the next block')
    else:
        print(PREDICT_NET, ' not found!')
else:
    print(INIT_NET, ' not found@')

# Read the image, convert it to float32, then resize and centre-crop it to
# the model's input size.
img = skimage.io.imread(IMG_LOCATION)
img = skimage.img_as_float(img).astype(np.float32)
img = rescale(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
img = crop_center(img, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)

# HWC -> CHW, then reorder the channels RGB -> BGR.
img = img.transpose(2, 0, 1)
img = img[[2, 1, 0], :, :]
# Scale by 255 and subtract the mean.
img = img * 255 - mean
# Add a leading batch axis: CHW -> NCHW.
img = img[np.newaxis, ...].astype(np.float32)
print('img shape nchw:{}'.format(img.shape))


# New material starts here: loading the model, running a prediction and
# extracting the result.

# The .pb files hold serialized binary data, so they must be read as bytes.
# Opening them in text mode with encoding='utf-8' raises UnicodeDecodeError
# on Python 3.
with open(INIT_NET, 'rb') as f:
    init_net = f.read()
with open(PREDICT_NET, 'rb') as f:
    predict_net = f.read()

# Build a predictor from the two serialized nets.
p = workspace.Predictor(init_net, predict_net)

# Run the preprocessed NCHW image through the network.
result = p.run([img])
result = np.asarray(result)
print('result shape :{}'.format(result.shape))

參考https://zhuanlan.zhihu.com/p/34701037

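但是在python3下運行時出現:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xf8 in position 1: invalid start byte

原因是.pb模型文件是二進制數據,卻用文本模式(encoding='utf-8')打開;改用open(path, 'rb')以二進制方式讀取即可,與模型是python2還是python3無關。當時誤以爲是python3無法加載python2的模型,因此重新下載、安裝了caffe2的python2版本。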

此章是前面知識點的小集合

重點在於加載模型(read()),預測(workspace.Predictor())以及預測結果的分析(未完成)

後續:

想在conda裏面安裝caffe2,但系統是win10,再加上各種其他環境的影響,沒有安裝成功。結果的處理,需要以後用實驗室電腦學習。

 

# The rest of this digs through the prediction output (Python 3 version).
import urllib.request

# NOTE(review): np.delete without an axis flattens the output and drops the
# element at flat index 1; together with the 1-based numbering below this is
# carried over unchanged from the original snippet -- confirm that the
# resulting class ids line up with the label list before relying on them.
results = np.delete(result, 1)
index = 0    # class id with the highest score seen so far
highest = 0  # that highest score
ranked = []  # (class id, score) pairs, e.g. [(1, 0.8), (2, 0.9)]
for i, r in enumerate(results, start=1):
    ranked.append((i, r))
    if r > highest:
        highest = r
        index = i

# top 3 results
print("Raw top 3 results:", sorted(ranked, key=lambda pair: pair[1], reverse=True)[:3])

# Fetch the label list; each line is split at the first ":" into a class id
# and a label, and the label belonging to the winning class id is printed.
with urllib.request.urlopen(codes) as response:
    for raw_line in response:
        # The response yields bytes on Python 3; decode before splitting.
        line = raw_line.decode('utf-8')
        code, label = line.partition(":")[::2]
        if code.strip() == str(index):
            print(MODEL[0], "infers that the image contains ", label.strip()[1:-2], "with a ", highest*100, "% probability")

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章